From c3aeb275c25a0623fbead69fa35ceca588651a82 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Wed, 29 Jan 2020 19:58:14 +0100 Subject: [PATCH 01/25] add derive macros for weak & strong enums and structs --- sqlx-core/src/mysql/mod.rs | 2 +- sqlx-core/src/mysql/protocol/type.rs | 3 + sqlx-core/src/mysql/types/mod.rs | 9 +- sqlx-core/src/postgres/types/mod.rs | 4 + sqlx-macros/src/derives.rs | 909 ++++++++++++++++++++++++--- sqlx-macros/src/lib.rs | 26 +- src/lib.rs | 3 + src/types.rs | 6 + tests/derives.rs | 328 ++++++++-- 9 files changed, 1152 insertions(+), 138 deletions(-) create mode 100644 src/types.rs diff --git a/sqlx-core/src/mysql/mod.rs b/sqlx-core/src/mysql/mod.rs index 7efddc38..1cd9cd2a 100644 --- a/sqlx-core/src/mysql/mod.rs +++ b/sqlx-core/src/mysql/mod.rs @@ -29,4 +29,4 @@ pub type MySqlPool = crate::pool::Pool; make_query_as!(MySqlQueryAs, MySql, MySqlRow); impl_map_row_for_row!(MySql, MySqlRow); impl_column_index_for_row!(MySql); -impl_from_row_for_tuples!(MySql, MySqlRow); +impl_from_row_for_tuples!(MySql, MySqlRow); \ No newline at end of file diff --git a/sqlx-core/src/mysql/protocol/type.rs b/sqlx-core/src/mysql/protocol/type.rs index 38ecb453..792ccb57 100644 --- a/sqlx-core/src/mysql/protocol/type.rs +++ b/sqlx-core/src/mysql/protocol/type.rs @@ -39,6 +39,9 @@ type_id_consts! { pub const VAR_CHAR: TypeId = TypeId(253); // or VAR_BINARY pub const TEXT: TypeId = TypeId(252); // or BLOB + // Enum + pub const ENUM: TypeId = TypeId(247); + // More Bytes pub const TINY_BLOB: TypeId = TypeId(249); pub const MEDIUM_BLOB: TypeId = TypeId(250); diff --git a/sqlx-core/src/mysql/types/mod.rs b/sqlx-core/src/mysql/types/mod.rs index 3c1311b0..343d39aa 100644 --- a/sqlx-core/src/mysql/types/mod.rs +++ b/sqlx-core/src/mysql/types/mod.rs @@ -64,6 +64,11 @@ impl MySqlTypeInfo { _ => None, } } + + #[doc(hidden)] + pub fn r#enum() -> Self { + Self::new(TypeId::ENUM) + } } impl Display for MySqlTypeInfo { @@ -85,6 +90,7 @@ impl TypeInfo for MySqlTypeInfo { | TypeId::TINY_BLOB | TypeId::MEDIUM_BLOB | TypeId::LONG_BLOB + | TypeId::ENUM if (self.is_binary == other.is_binary) && match other.id { TypeId::VAR_CHAR @@ -92,7 +98,8 @@ impl TypeInfo for MySqlTypeInfo { | TypeId::CHAR | TypeId::TINY_BLOB | TypeId::MEDIUM_BLOB - | TypeId::LONG_BLOB => true, + | TypeId::LONG_BLOB + | TypeId::ENUM => true, _ => false, } => diff --git a/sqlx-core/src/postgres/types/mod.rs b/sqlx-core/src/postgres/types/mod.rs index f017ac7f..ab676e9a 100644 --- a/sqlx-core/src/postgres/types/mod.rs +++ b/sqlx-core/src/postgres/types/mod.rs @@ -57,6 +57,10 @@ impl PgTypeInfo { _ => None, } } + + pub fn oid(&self) -> u32 { + self.id.0 + } } impl Display for PgTypeInfo { diff --git a/sqlx-macros/src/derives.rs b/sqlx-macros/src/derives.rs index b4e73a07..fdba0742 100644 --- a/sqlx-macros/src/derives.rs +++ b/sqlx-macros/src/derives.rs @@ -1,108 +1,847 @@ +use proc_macro2::Ident; use quote::quote; -use syn::{parse_quote, Data, DataStruct, DeriveInput, Fields, FieldsUnnamed}; +use std::iter::FromIterator; +use syn::punctuated::Punctuated; +use syn::token::Comma; +use syn::{ + parse_quote, Arm, Attribute, Block, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, + Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Variant,Stmt, +}; -pub(crate) fn expand_derive_encode(input: DeriveInput) -> syn::Result { - match &input.data { - Data::Struct(DataStruct { - fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), - .. - }) if unnamed.len() == 1 => { - let ident = &input.ident; - let ty = &unnamed.first().unwrap().ty; +macro_rules! assert_attribute { + ($e:expr, $err:expr, $input:expr) => { + if !$e { + return Err(syn::Error::new_spanned($input, $err)); + } + }; +} - // extract type generics - let generics = &input.generics; - let (_, ty_generics, _) = generics.split_for_impl(); +struct SqlxAttributes { + transparent: bool, + postgres_oid: Option, + repr: Option, + rename: Option, +} - // add db type for impl generics & where clause - let mut generics = generics.clone(); - generics.params.insert(0, parse_quote!(DB: sqlx::Database)); - generics - .make_where_clause() - .predicates - .push(parse_quote!(#ty: sqlx::encode::Encode)); - let (impl_generics, _, where_clause) = generics.split_for_impl(); +fn parse_attributes(input: &[Attribute]) -> syn::Result { + let mut transparent = None; + let mut postgres_oid = None; + let mut repr = None; + let mut rename = None; - Ok(quote!( - impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { - fn encode(&self, buf: &mut ::RawBuffer) { - sqlx::encode::Encode::encode(&self.0, buf) - } - fn encode_nullable(&self, buf: &mut ::RawBuffer) -> sqlx::encode::IsNull { - sqlx::encode::Encode::encode_nullable(&self.0, buf) - } - fn size_hint(&self) -> usize { - sqlx::encode::Encode::size_hint(&self.0) + macro_rules! fail { + ($t:expr, $m:expr) => { + return Err(syn::Error::new_spanned($t, $m)); + }; + } + + macro_rules! try_set { + ($i:ident, $v:expr, $t:expr) => { + match $i { + None => $i = Some($v), + Some(_) => fail!($t, "duplicate attribute"), + } + }; + } + + for attr in input { + let meta = attr + .parse_meta() + .map_err(|e| syn::Error::new_spanned(attr, e))?; + match meta { + Meta::List(list) if list.path.is_ident("sqlx") => { + for value in list.nested.iter() { + match value { + NestedMeta::Meta(meta) => match meta { + Meta::Path(p) if p.is_ident("transparent") => { + try_set!(transparent, true, value) + } + Meta::NameValue(MetaNameValue { + path, + lit: Lit::Str(val), + .. + }) if path.is_ident("rename") => try_set!(rename, val.value(), value), + Meta::List(list) if list.path.is_ident("postgres") => { + for value in list.nested.iter() { + match value { + NestedMeta::Meta(Meta::NameValue(MetaNameValue { + path, + lit: Lit::Int(val), + .. + })) if path.is_ident("oid") => { + try_set!(postgres_oid, val.base10_parse()?, value); + } + u => fail!(u, "unexpected value"), + } + } + } + + u => fail!(u, "unexpected attribute"), + }, + u => fail!(u, "unexpected attribute"), } } - )) + } + Meta::List(list) if list.path.is_ident("repr") => { + if list.nested.len() != 1 { + fail!(&list.nested, "expected one value") + } + match list.nested.first().unwrap() { + NestedMeta::Meta(Meta::Path(p)) if p.get_ident().is_some() => { + try_set!(repr, p.get_ident().unwrap().clone(), list); + } + u => fail!(u, "unexpected value"), + } + } + _ => {} } - _ => Err(syn::Error::new_spanned( - input, - "expected a tuple struct with a single field", - )), } + + Ok(SqlxAttributes { + transparent: transparent.unwrap_or(false), + postgres_oid, + repr, + rename, + }) } -pub(crate) fn expand_derive_decode(input: DeriveInput) -> syn::Result { +fn check_transparent_attributes(input: &DeriveInput, field: &Field) -> syn::Result<()> { + let attributes = parse_attributes(&input.attrs)?; + assert_attribute!( + attributes.transparent, + "expected #[sqlx(transparent)]", + input + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + field + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); + let attributes = parse_attributes(&field.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + field + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + field + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + field + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field); + Ok(()) +} + +fn check_enum_attributes<'a>( + input: &'a DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = parse_attributes(&input.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + input + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + input + ); + + for variant in variants { + let attributes = parse_attributes(&variant.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + variant + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + variant + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", variant); + } + + Ok(attributes) +} + +fn check_weak_enum_attributes( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = check_enum_attributes(input, variants)?; + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!(attributes.repr.is_some(), "expected #[repr(..)]", input); + for variant in variants { + let attributes = parse_attributes(&variant.attrs)?; + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + variant + ); + } + Ok(attributes.repr.unwrap()) +} + +fn check_strong_enum_attributes( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = check_enum_attributes(input, variants)?; + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_some(), + "expected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); + Ok(attributes) +} + +fn check_struct_attributes<'a>( + input: &'a DeriveInput, + fields: &Punctuated, +) -> syn::Result { + let attributes = parse_attributes(&input.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + input + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_some(), + "expected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + input + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); + + for field in fields { + let attributes = parse_attributes(&field.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + field + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + field + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + field + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field); + } + + Ok(attributes) +} + +pub(crate) fn expand_derive_encode(input: &DeriveInput) -> syn::Result { + let args = parse_attributes(&input.attrs)?; + match &input.data { Data::Struct(DataStruct { fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), .. }) if unnamed.len() == 1 => { - let ident = &input.ident; - let ty = &unnamed.first().unwrap().ty; - - // extract type generics - let generics = &input.generics; - let (_, ty_generics, _) = generics.split_for_impl(); - - let mut impls = Vec::new(); - - if cfg!(feature = "postgres") { - let mut generics = generics.clone(); - generics.params.insert(0, parse_quote!('de)); - generics - .make_where_clause() - .predicates - .push(parse_quote!(#ty: sqlx::decode::Decode<'de, sqlx::Postgres>)); - - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - impls.push(quote!( - impl #impl_generics sqlx::decode::Decode<'de, sqlx::Postgres> for #ident #ty_generics #where_clause { - fn decode(value: >::RawValue) -> sqlx::Result { - <#ty as sqlx::decode::Decode<'de, sqlx::Postgres>>::decode(value).map(Self) - } - } - )); - } - - if cfg!(feature = "mysql") { - let mut generics = generics.clone(); - generics.params.insert(0, parse_quote!('de)); - generics - .make_where_clause() - .predicates - .push(parse_quote!(#ty: sqlx::decode::Decode<'de, sqlx::MySql>)); - - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - impls.push(quote!( - impl #impl_generics sqlx::decode::Decode<'de, sqlx::MySql> for #ident #ty_generics #where_clause { - fn decode(value: >::RawValue) -> sqlx::Result { - <#ty as sqlx::decode::Decode<'de, sqlx::MySql>>::decode(value).map(Self) - } - } - )); - } - - // panic!("{}", q) - Ok(quote!(#(#impls)*)) + expand_derive_encode_transparent(&input, unnamed.first().unwrap()) } + Data::Enum(DataEnum { variants, .. }) => match args.repr { + Some(_) => expand_derive_encode_weak_enum(input, variants), + None => expand_derive_encode_strong_enum(input, variants), + }, + Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { named, .. }), + .. + }) => expand_derive_encode_struct(input, named), _ => Err(syn::Error::new_spanned( input, "expected a tuple struct with a single field", )), } } + +fn expand_derive_encode_transparent( + input: &DeriveInput, + field: &Field, +) -> syn::Result { + check_transparent_attributes(input, field)?; + + let ident = &input.ident; + let ty = &field.ty; + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::encode::Encode)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + Ok(quote!( + impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { + fn encode(&self, buf: &mut std::vec::Vec) { + sqlx::encode::Encode::encode(&self.0, buf) + } + fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { + sqlx::encode::Encode::encode_nullable(&self.0, buf) + } + fn size_hint(&self) -> usize { + sqlx::encode::Encode::size_hint(&self.0) + } + } + )) +} + +fn expand_derive_encode_weak_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let repr = check_weak_enum_attributes(input, &variants)?; + + let ident = &input.ident; + + Ok(quote!( + impl sqlx::encode::Encode for #ident where #repr: sqlx::encode::Encode { + fn encode(&self, buf: &mut std::vec::Vec) { + sqlx::encode::Encode::encode(&(*self as #repr), buf) + } + fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { + sqlx::encode::Encode::encode_nullable(&(*self as #repr), buf) + } + fn size_hint(&self) -> usize { + sqlx::encode::Encode::size_hint(&(*self as #repr)) + } + } + )) +} + +fn expand_derive_encode_strong_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + check_strong_enum_attributes(input, &variants)?; + + let ident = &input.ident; + + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + let mut value_arms = Vec::new(); + for v in variants { + let id = &v.ident; + let attributes = parse_attributes(&v.attrs)?; + if let Some(rename) = attributes.rename { + value_arms.push(quote!(#ident :: #id => #rename,)); + } else { + let name = id.to_string(); + value_arms.push(quote!(#ident :: #id => #name,)); + } + } + + tts.extend(quote!( + impl sqlx::encode::Encode for #ident where str: sqlx::encode::Encode { + fn encode(&self, buf: &mut std::vec::Vec) { + let val = match self { + #(#value_arms)* + }; + >::encode(val, buf) + } + fn size_hint(&self) -> usize { + let val = match self { + #(#value_arms)* + }; + >::size_hint(val) + } + } + )); + } + + Ok(tts) +} + +fn expand_derive_encode_struct( + input: &DeriveInput, + fields: &Punctuated, +) -> syn::Result { + check_struct_attributes(input, &fields)?; + + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "postgres") { + let ident = &input.ident; + + let column_count = fields.len(); + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + let predicates = &mut generics.make_where_clause().predicates; + for field in fields { + let ty = &field.ty; + predicates.push(parse_quote!(#ty: sqlx::encode::Encode)); + predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); + } + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + let mut writes: Vec = Vec::new(); + for field in fields { + let id = &field.ident; + let ty = &field.ty; + writes.push(parse_quote!({ + // write oid + let info = >::type_info(); + buf.extend(&info.oid().to_be_bytes()); + + // write zeros for length + buf.extend(&[0; 4]); + + let start = buf.len(); + sqlx::encode::Encode::::encode(&self. #id, buf); + let end = buf.len(); + let size = end - start; + + // replaces zeros with actual length + buf[start-4..start].copy_from_slice(&(size as u32).to_be_bytes()); + })); + } + + let mut sizes: Vec = Vec::new(); + for field in fields { + let id = &field.ident; + let ty = &field.ty; + sizes.push( + parse_quote!(<#ty as sqlx::encode::Encode>::size_hint(&self. #id)), + ); + } + + tts.extend(quote!( + impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { + fn encode(&self, buf: &mut std::vec::Vec) { + buf.extend(&(#column_count as u32).to_be_bytes()); + #(#writes)* + } + fn size_hint(&self) -> usize { + 4 + #column_count * (4 + 4) + #(#sizes)+* + } + } + )); + } + + Ok(tts) +} + +pub(crate) fn expand_derive_decode(input: &DeriveInput) -> syn::Result { + let attrs = parse_attributes(&input.attrs)?; + match &input.data { + Data::Struct(DataStruct { + fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), + .. + }) if unnamed.len() == 1 => { + expand_derive_decode_transparent(input, unnamed.first().unwrap()) + } + Data::Enum(DataEnum { variants, .. }) => match attrs.repr { + Some(_) => expand_derive_decode_weak_enum(input, variants), + None => expand_derive_decode_strong_enum(input, variants), + }, + Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { named, .. }), + .. + }) => expand_derive_decode_struct(input, named), + _ => Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )), + } +} + +fn expand_derive_decode_transparent( + input: &DeriveInput, + field: &Field, +) -> syn::Result { + check_transparent_attributes(input, field)?; + + let ident = &input.ident; + let ty = &field.ty; + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::decode::Decode)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + Ok(quote!( + impl #impl_generics sqlx::decode::Decode for #ident #ty_generics #where_clause { + fn decode(raw: &[u8]) -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode(raw).map(Self) + } + fn decode_null() -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode_null().map(Self) + } + fn decode_nullable(raw: std::option::Option<&[u8]>) -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode_nullable(raw).map(Self) + } + } + )) +} + +fn expand_derive_decode_weak_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let repr = check_weak_enum_attributes(input, &variants)?; + + let ident = &input.ident; + let arms = variants + .iter() + .map(|v| { + let id = &v.ident; + parse_quote!(_ if (#ident :: #id as #repr) == val => Ok(#ident :: #id),) + }) + .collect::>(); + + Ok(quote!( + impl sqlx::decode::Decode for #ident where #repr: sqlx::decode::Decode { + fn decode(raw: &[u8]) -> std::result::Result { + let val = <#repr as sqlx::decode::Decode>::decode(raw)?; + match val { + #(#arms)* + _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) + } + } + } + )) +} + +fn expand_derive_decode_strong_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + check_strong_enum_attributes(input, &variants)?; + + let ident = &input.ident; + + let mut value_arms = Vec::new(); + for v in variants { + let id = &v.ident; + let attributes = parse_attributes(&v.attrs)?; + if let Some(rename) = attributes.rename { + value_arms.push(quote!(#rename => Ok(#ident :: #id),)); + } else { + let name = id.to_string(); + value_arms.push(quote!(#name => Ok(#ident :: #id),)); + } + } + + // TODO: prevent heap allocation + Ok(quote!( + impl sqlx::decode::Decode for #ident where String: sqlx::decode::Decode { + fn decode(buf: &[u8]) -> std::result::Result { + let val = >::decode(buf)?; + match val.as_str() { + #(#value_arms)* + _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) + } + } + } + )) +} + +fn expand_derive_decode_struct( + input: &DeriveInput, + fields: &Punctuated, +) -> syn::Result { + check_struct_attributes(input, fields)?; + + let ident = &input.ident; + + let column_count = fields.len(); + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + let predicates = &mut generics.make_where_clause().predicates; + for field in fields { + let ty = &field.ty; + predicates.push(parse_quote!(#ty: sqlx::decode::Decode)); + predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); + } + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + let mut reads: Vec> = Vec::new(); + let mut names: Vec = Vec::new(); + for field in fields { + let id = &field.ident; + names.push(id.clone().unwrap()); + let ty = &field.ty; + reads.push(parse_quote!( + if buf.len() < 8 { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); + } + + let oid = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[0..4]).unwrap()); + if oid != >::type_info().oid() { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid oid"))); + } + + let len = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[4..8]).unwrap()) as usize; + + if buf.len() < 8 + len { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); + } + + let raw = &buf[8..8+len]; + let #id = <#ty as sqlx::decode::Decode>::decode(raw)?; + + let buf = &buf[8+len..]; + )); + } + let reads = reads.into_iter().flatten(); + + Ok(quote!( + impl #impl_generics sqlx::decode::Decode for #ident#ty_generics #where_clause { + fn decode(buf: &[u8]) -> std::result::Result { + if buf.len() < 4 { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); + } + + let column_count = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[..4]).unwrap()) as usize; + if column_count != #column_count { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid column count"))); + } + let buf = &buf[4..]; + + #(#reads)* + + if !buf.is_empty() { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new(format!("Too much data sent ({} bytes left)", buf.len())))); + } + + Ok(#ident { + #(#names),* + }) + } + } + )) +} + +pub(crate) fn expand_derive_has_sql_type( + input: &DeriveInput, +) -> syn::Result { + let attrs = parse_attributes(&input.attrs)?; + match &input.data { + Data::Struct(DataStruct { + fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), + .. + }) if unnamed.len() == 1 => { + expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap()) + } + Data::Enum(DataEnum { variants, .. }) => match attrs.repr { + Some(_) => expand_derive_has_sql_type_weak_enum(input, variants), + None => expand_derive_has_sql_type_strong_enum(input, variants), + }, + Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { named, .. }), + .. + }) => expand_derive_has_sql_type_struct(input, named), + _ => Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )), + } +} + +fn expand_derive_has_sql_type_transparent( + input: &DeriveInput, + field: &Field, +) -> syn::Result { + check_transparent_attributes(input, field)?; + + let ident = &input.ident; + let ty = &field.ty; + + // extract type generics + let generics = &input.generics; + let (impl_generics, ty_generics, _) = generics.split_for_impl(); + + // add db type for clause + let mut generics = generics.clone(); + generics + .make_where_clause() + .predicates + .push(parse_quote!(Self: sqlx::types::HasSqlType<#ty>)); + let (_, _, where_clause) = generics.split_for_impl(); + + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl #impl_generics sqlx::types::HasSqlType< #ident #ty_generics > for sqlx::MySql #where_clause { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + if cfg!(feature = "postgres") { + tts.extend(quote!( + impl #impl_generics sqlx::types::HasSqlType< #ident #ty_generics > for sqlx::Postgres #where_clause { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + Ok(tts) +} + +fn expand_derive_has_sql_type_weak_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let repr = check_weak_enum_attributes(input, variants)?; + + let ident = &input.ident; + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::MySql where Self: sqlx::types::HasSqlType< #repr > { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + if cfg!(feature = "postgres") { + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres where Self: sqlx::types::HasSqlType< #repr > { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + Ok(tts) +} + +fn expand_derive_has_sql_type_strong_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = check_strong_enum_attributes(input, variants)?; + + let ident = &input.ident; + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::MySql { + fn type_info() -> Self::TypeInfo { + sqlx::mysql::MySqlTypeInfo::r#enum() + } + } + )); + } + + if cfg!(feature = "postgres") { + let oid = attributes.postgres_oid.unwrap(); + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres { + fn type_info() -> Self::TypeInfo { + sqlx::postgres::PgTypeInfo::with_oid(#oid) + } + } + )); + } + + Ok(tts) +} + +fn expand_derive_has_sql_type_struct( + input: &DeriveInput, + fields: &Punctuated, +) -> syn::Result { + let attributes = check_struct_attributes(input, fields)?; + + let ident = &input.ident; + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "postgres") { + let oid = attributes.postgres_oid.unwrap(); + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres { + fn type_info() -> Self::TypeInfo { + sqlx::postgres::PgTypeInfo::with_oid(#oid) + } + } + )); + } + + Ok(tts) +} + +pub(crate) fn expand_derive_type(input: &DeriveInput) -> syn::Result { + let encode_tts = expand_derive_encode(input)?; + let decode_tts = expand_derive_decode(input)?; + let has_sql_type_tts = expand_derive_has_sql_type(input)?; + + let combined = proc_macro2::TokenStream::from_iter( + encode_tts + .into_iter() + .chain(decode_tts) + .chain(has_sql_type_tts), + ); + Ok(combined) +} diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs index 0bed307f..f74c775e 100644 --- a/sqlx-macros/src/lib.rs +++ b/sqlx-macros/src/lib.rs @@ -151,19 +151,37 @@ pub fn query_file_as(input: TokenStream) -> TokenStream { async_macro!(db, input: QueryAsMacroInput => expand_query_file_as(input, db)) } -#[proc_macro_derive(Encode)] +#[proc_macro_derive(Encode, attributes(sqlx))] pub fn derive_encode(tokenstream: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); - match derives::expand_derive_encode(input) { + match derives::expand_derive_encode(&input) { Ok(ts) => ts.into(), Err(e) => e.to_compile_error().into(), } } -#[proc_macro_derive(Decode)] +#[proc_macro_derive(Decode, attributes(sqlx))] pub fn derive_decode(tokenstream: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); - match derives::expand_derive_decode(input) { + match derives::expand_derive_decode(&input) { + Ok(ts) => ts.into(), + Err(e) => e.to_compile_error().into(), + } +} + +#[proc_macro_derive(HasSqlType, attributes(sqlx))] +pub fn derive_has_sql_type(tokenstream: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); + match derives::expand_derive_has_sql_type(&input) { + Ok(ts) => ts.into(), + Err(e) => e.to_compile_error().into(), + } +} + +#[proc_macro_derive(Type, attributes(sqlx))] +pub fn derive_type(tokenstream: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); + match derives::expand_derive_type(&input) { Ok(ts) => ts.into(), Err(e) => e.to_compile_error().into(), } diff --git a/src/lib.rs b/src/lib.rs index d1dc6aaa..3342b09b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -40,6 +40,9 @@ pub use sqlx_core::sqlite::{self, Sqlite, SqliteConnection, SqlitePool}; #[doc(hidden)] pub extern crate sqlx_macros; +#[cfg(feature = "macros")] +pub use sqlx_macros::Type; + #[cfg(feature = "macros")] mod macros; diff --git a/src/types.rs b/src/types.rs new file mode 100644 index 00000000..6eb3cab6 --- /dev/null +++ b/src/types.rs @@ -0,0 +1,6 @@ +//! Traits linking Rust types to SQL types. + +pub use sqlx_core::types::*; + +#[cfg(feature = "macros")] +pub use sqlx_macros::HasSqlType; diff --git a/tests/derives.rs b/tests/derives.rs index d5120031..6ec2cc90 100644 --- a/tests/derives.rs +++ b/tests/derives.rs @@ -1,77 +1,311 @@ use sqlx::decode::Decode; use sqlx::encode::Encode; +use sqlx::types::{HasSqlType, TypeInfo}; +use std::fmt::Debug; -#[derive(PartialEq, Debug, Encode, Decode)] -struct Foo(i32); +#[derive(PartialEq, Debug, Encode, Decode, HasSqlType)] +#[sqlx(transparent)] +struct Transparent(i32); -#[test] -#[cfg(feature = "postgres")] -fn encode_with_postgres() { - use sqlx_core::postgres::Postgres; +#[derive(PartialEq, Debug, Clone, Copy, Encode, Decode, HasSqlType)] +#[repr(i32)] +#[allow(dead_code)] +enum Weak { + One, + Two, + Three, +} - let example = Foo(0x1122_3344); +#[derive(PartialEq, Debug, Encode, Decode, HasSqlType)] +#[sqlx(postgres(oid = 10101010))] +#[allow(dead_code)] +enum Strong { + One, + Two, + #[sqlx(rename = "four")] + Three, +} - let mut encoded = Vec::new(); - let mut encoded_orig = Vec::new(); - - Encode::::encode(&example, &mut encoded); - Encode::::encode(&example.0, &mut encoded_orig); - - assert_eq!(encoded, encoded_orig); +#[derive(PartialEq, Debug, Encode, Decode, HasSqlType)] +#[sqlx(postgres(oid = 20202020))] +#[allow(dead_code)] +struct Struct { + field1: String, + field2: i64, + field3: bool, } #[test] #[cfg(feature = "mysql")] -fn encode_with_mysql() { - use sqlx_core::mysql::MySql; - - let example = Foo(0x1122_3344); - - let mut encoded = Vec::new(); - let mut encoded_orig = Vec::new(); - - Encode::::encode(&example, &mut encoded); - Encode::::encode(&example.0, &mut encoded_orig); - - assert_eq!(encoded, encoded_orig); -} - -#[test] -#[cfg(feature = "mysql")] -fn decode_mysql() { - decode_with_db(); +fn encode_transparent_mysql() { + encode_transparent::(); } #[test] #[cfg(feature = "postgres")] -fn decode_postgres() { - decode_with_db(); +fn encode_transparent_postgres() { + encode_transparent::(); } -#[cfg(feature = "postgres")] -fn decode_with_db() +#[allow(dead_code)] +fn encode_transparent() where - Foo: for<'de> Decode<'de, sqlx::Postgres> + Encode, + Transparent: Encode, + i32: Encode, { - let example = Foo(0x1122_3344); + let example = Transparent(0x1122_3344); + + let mut encoded = Vec::new(); + let mut encoded_orig = Vec::new(); + + Encode::::encode(&example, &mut encoded); + Encode::::encode(&example.0, &mut encoded_orig); + + assert_eq!(encoded, encoded_orig); +} + +#[test] +#[cfg(feature = "mysql")] +fn encode_weak_enum_mysql() { + encode_weak_enum::(); +} + +#[test] +#[cfg(feature = "postgres")] +fn encode_weak_enum_postgres() { + encode_weak_enum::(); +} + +#[allow(dead_code)] +fn encode_weak_enum() +where + Weak: Encode, + i32: Encode, +{ + for example in [Weak::One, Weak::Two, Weak::Three].iter() { + let mut encoded = Vec::new(); + let mut encoded_orig = Vec::new(); + + Encode::::encode(example, &mut encoded); + Encode::::encode(&(*example as i32), &mut encoded_orig); + + assert_eq!(encoded, encoded_orig); + } +} + +#[test] +#[cfg(feature = "mysql")] +fn encode_strong_enum_mysql() { + encode_strong_enum::(); +} + +#[test] +#[cfg(feature = "postgres")] +fn encode_strong_enum_postgres() { + encode_strong_enum::(); +} + +#[allow(dead_code)] +fn encode_strong_enum() +where + Strong: Encode, + str: Encode, +{ + for (example, name) in [ + (Strong::One, "One"), + (Strong::Two, "Two"), + (Strong::Three, "four"), + ] + .iter() + { + let mut encoded = Vec::new(); + let mut encoded_orig = Vec::new(); + + Encode::::encode(example, &mut encoded); + Encode::::encode(name, &mut encoded_orig); + + assert_eq!(encoded, encoded_orig); + } +} + +#[test] +#[cfg(feature = "postgres")] +fn encode_struct_postgres() { + let field1 = "Foo".to_string(); + let field2 = 3; + let field3 = false; + + let example = Struct { + field1: field1.clone(), + field2, + field3, + }; let mut encoded = Vec::new(); Encode::::encode(&example, &mut encoded); - let decoded = Foo::decode(Some(sqlx::postgres::PgValue::Binary(&encoded))).unwrap(); - assert_eq!(example, decoded); + let string_oid = >::type_info().oid(); + let i64_oid = >::type_info().oid(); + let bool_oid = >::type_info().oid(); + + // 3 columns + assert_eq!(&[0, 0, 0, 3], &encoded[..4]); + let encoded = &encoded[4..]; + + // check field1 (string) + assert_eq!(&string_oid.to_be_bytes(), &encoded[0..4]); + assert_eq!(&(field1.len() as u32).to_be_bytes(), &encoded[4..8]); + assert_eq!(field1.as_bytes(), &encoded[8..8 + field1.len()]); + let encoded = &encoded[8 + field1.len()..]; + + // check field2 (i64) + assert_eq!(&i64_oid.to_be_bytes(), &encoded[0..4]); + assert_eq!(&8u32.to_be_bytes(), &encoded[4..8]); + assert_eq!(field2.to_be_bytes(), &encoded[8..16]); + let encoded = &encoded[16..]; + + // check field3 (bool) + assert_eq!(&bool_oid.to_be_bytes(), &encoded[0..4]); + assert_eq!(&1u32.to_be_bytes(), &encoded[4..8]); + assert_eq!(field3, encoded[8] != 0); + let encoded = &encoded[9..]; + + assert!(encoded.is_empty()); + + let string_size = >::size_hint(&field1); + let i64_size = >::size_hint(&field2); + let bool_size = >::size_hint(&field3); + + assert_eq!( + 4 + 3 * (4 + 4) + string_size + i64_size + bool_size, + example.size_hint() + ); } +#[test] #[cfg(feature = "mysql")] -fn decode_with_db() -where - Foo: for<'de> Decode<'de, sqlx::MySql> + Encode, -{ - let example = Foo(0x1122_3344); +fn decode_transparent_mysql() { + decode_with_db::(Transparent(0x1122_3344)); +} +#[test] +#[cfg(feature = "postgres")] +fn decode_transparent_postgres() { + decode_with_db::(Transparent(0x1122_3344)); +} + +#[test] +#[cfg(feature = "mysql")] +fn decode_weak_enum_mysql() { + decode_with_db::(Weak::One); + decode_with_db::(Weak::Two); + decode_with_db::(Weak::Three); +} + +#[test] +#[cfg(feature = "postgres")] +fn decode_weak_enum_postgres() { + decode_with_db::(Weak::One); + decode_with_db::(Weak::Two); + decode_with_db::(Weak::Three); +} + +#[test] +#[cfg(feature = "mysql")] +fn decode_strong_enum_mysql() { + decode_with_db::(Strong::One); + decode_with_db::(Strong::Two); + decode_with_db::(Strong::Three); +} + +#[test] +#[cfg(feature = "postgres")] +fn decode_strong_enum_postgres() { + decode_with_db::(Strong::One); + decode_with_db::(Strong::Two); + decode_with_db::(Strong::Three); +} + +#[test] +#[cfg(feature = "postgres")] +fn decode_struct_postgres() { + decode_with_db::(Struct { + field1: "Foo".to_string(), + field2: 3, + field3: true, + }); +} + +#[allow(dead_code)] +fn decode_with_db + Encode + PartialEq + Debug>(example: V) { let mut encoded = Vec::new(); - Encode::::encode(&example, &mut encoded); + Encode::::encode(&example, &mut encoded); - let decoded = Foo::decode(Some(sqlx::mysql::MySqlValue::Binary(&encoded))).unwrap(); + let decoded = V::decode(&encoded).unwrap(); assert_eq!(example, decoded); } + +#[test] +#[cfg(feature = "mysql")] +fn has_sql_type_transparent_mysql() { + has_sql_type_transparent::(); +} + +#[test] +#[cfg(feature = "postgres")] +fn has_sql_type_transparent_postgres() { + has_sql_type_transparent::(); +} + +#[allow(dead_code)] +fn has_sql_type_transparent() +where + DB: HasSqlType + HasSqlType, +{ + let info: DB::TypeInfo = >::type_info(); + let info_orig: DB::TypeInfo = >::type_info(); + assert!(info.compatible(&info_orig)); +} + +#[test] +#[cfg(feature = "mysql")] +fn has_sql_type_weak_enum_mysql() { + has_sql_type_weak_enum::(); +} + +#[test] +#[cfg(feature = "postgres")] +fn has_sql_type_weak_enum_postgres() { + has_sql_type_weak_enum::(); +} + +#[allow(dead_code)] +fn has_sql_type_weak_enum() +where + DB: HasSqlType + HasSqlType, +{ + let info: DB::TypeInfo = >::type_info(); + let info_orig: DB::TypeInfo = >::type_info(); + assert!(info.compatible(&info_orig)); +} + +#[test] +#[cfg(feature = "mysql")] +fn has_sql_type_strong_enum_mysql() { + let info: sqlx::mysql::MySqlTypeInfo = >::type_info(); + assert!(info.compatible(&sqlx::mysql::MySqlTypeInfo::r#enum())) +} + +#[test] +#[cfg(feature = "postgres")] +fn has_sql_type_strong_enum_postgres() { + let info: sqlx::postgres::PgTypeInfo = >::type_info(); + assert!(info.compatible(&sqlx::postgres::PgTypeInfo::with_oid(10101010))) +} + +#[test] +#[cfg(feature = "postgres")] +fn has_sql_type_struct_postgres() { + let info: sqlx::postgres::PgTypeInfo = >::type_info(); + assert!(info.compatible(&sqlx::postgres::PgTypeInfo::with_oid(20202020))) +} From e166f062be7327bba7c455f54fe22b8731f99670 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Wed, 29 Jan 2020 19:59:09 +0100 Subject: [PATCH 02/25] format --- sqlx-macros/src/derives.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlx-macros/src/derives.rs b/sqlx-macros/src/derives.rs index fdba0742..5716eb8a 100644 --- a/sqlx-macros/src/derives.rs +++ b/sqlx-macros/src/derives.rs @@ -5,7 +5,7 @@ use syn::punctuated::Punctuated; use syn::token::Comma; use syn::{ parse_quote, Arm, Attribute, Block, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, - Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Variant,Stmt, + Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Stmt, Variant, }; macro_rules! assert_attribute { From d3cb84b893795f1e746336c38b262b553dbae2e4 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Wed, 29 Jan 2020 21:05:22 +0100 Subject: [PATCH 03/25] fix db type --- tests/derives.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/derives.rs b/tests/derives.rs index 6ec2cc90..4c22324e 100644 --- a/tests/derives.rs +++ b/tests/derives.rs @@ -121,8 +121,8 @@ where let mut encoded = Vec::new(); let mut encoded_orig = Vec::new(); - Encode::::encode(example, &mut encoded); - Encode::::encode(name, &mut encoded_orig); + Encode::::encode(example, &mut encoded); + Encode::::encode(*name, &mut encoded_orig); assert_eq!(encoded, encoded_orig); } From 9c96bc92ee1fbdc38220e3b959a6b899fea1c3ba Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Wed, 29 Jan 2020 21:05:52 +0100 Subject: [PATCH 04/25] move feature guard from strong_enum to struct --- sqlx-macros/src/derives.rs | 111 ++++++++++++++++++------------------- 1 file changed, 55 insertions(+), 56 deletions(-) diff --git a/sqlx-macros/src/derives.rs b/sqlx-macros/src/derives.rs index 5716eb8a..8ddc6203 100644 --- a/sqlx-macros/src/derives.rs +++ b/sqlx-macros/src/derives.rs @@ -358,40 +358,34 @@ fn expand_derive_encode_strong_enum( let ident = &input.ident; - let mut tts = proc_macro2::TokenStream::new(); - - if cfg!(feature = "mysql") { - let mut value_arms = Vec::new(); - for v in variants { - let id = &v.ident; - let attributes = parse_attributes(&v.attrs)?; - if let Some(rename) = attributes.rename { - value_arms.push(quote!(#ident :: #id => #rename,)); - } else { - let name = id.to_string(); - value_arms.push(quote!(#ident :: #id => #name,)); - } + let mut value_arms = Vec::new(); + for v in variants { + let id = &v.ident; + let attributes = parse_attributes(&v.attrs)?; + if let Some(rename) = attributes.rename { + value_arms.push(quote!(#ident :: #id => #rename,)); + } else { + let name = id.to_string(); + value_arms.push(quote!(#ident :: #id => #name,)); } - - tts.extend(quote!( - impl sqlx::encode::Encode for #ident where str: sqlx::encode::Encode { - fn encode(&self, buf: &mut std::vec::Vec) { - let val = match self { - #(#value_arms)* - }; - >::encode(val, buf) - } - fn size_hint(&self) -> usize { - let val = match self { - #(#value_arms)* - }; - >::size_hint(val) - } - } - )); } - Ok(tts) + Ok(quote!( + impl sqlx::encode::Encode for #ident where str: sqlx::encode::Encode { + fn encode(&self, buf: &mut std::vec::Vec) { + let val = match self { + #(#value_arms)* + }; + >::encode(val, buf) + } + fn size_hint(&self) -> usize { + let val = match self { + #(#value_arms)* + }; + >::size_hint(val) + } + } + )) } fn expand_derive_encode_struct( @@ -579,7 +573,7 @@ fn expand_derive_decode_strong_enum( // TODO: prevent heap allocation Ok(quote!( - impl sqlx::decode::Decode for #ident where String: sqlx::decode::Decode { + impl sqlx::decode::Decode for #ident where std::string::String: sqlx::decode::Decode { fn decode(buf: &[u8]) -> std::result::Result { let val = >::decode(buf)?; match val.as_str() { @@ -597,31 +591,34 @@ fn expand_derive_decode_struct( ) -> syn::Result { check_struct_attributes(input, fields)?; - let ident = &input.ident; + let mut tts = proc_macro2::TokenStream::new(); - let column_count = fields.len(); + if cfg!(feature = "postgres") { + let ident = &input.ident; - // extract type generics - let generics = &input.generics; - let (_, ty_generics, _) = generics.split_for_impl(); + let column_count = fields.len(); - // add db type for impl generics & where clause - let mut generics = generics.clone(); - let predicates = &mut generics.make_where_clause().predicates; - for field in fields { - let ty = &field.ty; - predicates.push(parse_quote!(#ty: sqlx::decode::Decode)); - predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); - } - let (impl_generics, _, where_clause) = generics.split_for_impl(); + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); - let mut reads: Vec> = Vec::new(); - let mut names: Vec = Vec::new(); - for field in fields { - let id = &field.ident; - names.push(id.clone().unwrap()); - let ty = &field.ty; - reads.push(parse_quote!( + // add db type for impl generics & where clause + let mut generics = generics.clone(); + let predicates = &mut generics.make_where_clause().predicates; + for field in fields { + let ty = &field.ty; + predicates.push(parse_quote!(#ty: sqlx::decode::Decode)); + predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); + } + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + let mut reads: Vec> = Vec::new(); + let mut names: Vec = Vec::new(); + for field in fields { + let id = &field.ident; + names.push(id.clone().unwrap()); + let ty = &field.ty; + reads.push(parse_quote!( if buf.len() < 8 { return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); } @@ -642,10 +639,10 @@ fn expand_derive_decode_struct( let buf = &buf[8+len..]; )); - } - let reads = reads.into_iter().flatten(); + } + let reads = reads.into_iter().flatten(); - Ok(quote!( + tts.extend(quote!( impl #impl_generics sqlx::decode::Decode for #ident#ty_generics #where_clause { fn decode(buf: &[u8]) -> std::result::Result { if buf.len() < 4 { @@ -670,6 +667,8 @@ fn expand_derive_decode_struct( } } )) + } + Ok(tts) } pub(crate) fn expand_derive_has_sql_type( From e603f5fcf617add03fbdcf35177e51f63bca452e Mon Sep 17 00:00:00 2001 From: freax13 Date: Thu, 30 Jan 2020 12:59:09 +0100 Subject: [PATCH 05/25] split derives into different files --- sqlx-macros/src/derives.rs | 846 ------------------------ sqlx-macros/src/derives/attributes.rs | 261 ++++++++ sqlx-macros/src/derives/decode.rs | 221 +++++++ sqlx-macros/src/derives/encode.rs | 208 ++++++ sqlx-macros/src/derives/has_sql_type.rs | 169 +++++ sqlx-macros/src/derives/mod.rs | 25 + 6 files changed, 884 insertions(+), 846 deletions(-) delete mode 100644 sqlx-macros/src/derives.rs create mode 100644 sqlx-macros/src/derives/attributes.rs create mode 100644 sqlx-macros/src/derives/decode.rs create mode 100644 sqlx-macros/src/derives/encode.rs create mode 100644 sqlx-macros/src/derives/has_sql_type.rs create mode 100644 sqlx-macros/src/derives/mod.rs diff --git a/sqlx-macros/src/derives.rs b/sqlx-macros/src/derives.rs deleted file mode 100644 index 8ddc6203..00000000 --- a/sqlx-macros/src/derives.rs +++ /dev/null @@ -1,846 +0,0 @@ -use proc_macro2::Ident; -use quote::quote; -use std::iter::FromIterator; -use syn::punctuated::Punctuated; -use syn::token::Comma; -use syn::{ - parse_quote, Arm, Attribute, Block, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, - Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Stmt, Variant, -}; - -macro_rules! assert_attribute { - ($e:expr, $err:expr, $input:expr) => { - if !$e { - return Err(syn::Error::new_spanned($input, $err)); - } - }; -} - -struct SqlxAttributes { - transparent: bool, - postgres_oid: Option, - repr: Option, - rename: Option, -} - -fn parse_attributes(input: &[Attribute]) -> syn::Result { - let mut transparent = None; - let mut postgres_oid = None; - let mut repr = None; - let mut rename = None; - - macro_rules! fail { - ($t:expr, $m:expr) => { - return Err(syn::Error::new_spanned($t, $m)); - }; - } - - macro_rules! try_set { - ($i:ident, $v:expr, $t:expr) => { - match $i { - None => $i = Some($v), - Some(_) => fail!($t, "duplicate attribute"), - } - }; - } - - for attr in input { - let meta = attr - .parse_meta() - .map_err(|e| syn::Error::new_spanned(attr, e))?; - match meta { - Meta::List(list) if list.path.is_ident("sqlx") => { - for value in list.nested.iter() { - match value { - NestedMeta::Meta(meta) => match meta { - Meta::Path(p) if p.is_ident("transparent") => { - try_set!(transparent, true, value) - } - Meta::NameValue(MetaNameValue { - path, - lit: Lit::Str(val), - .. - }) if path.is_ident("rename") => try_set!(rename, val.value(), value), - Meta::List(list) if list.path.is_ident("postgres") => { - for value in list.nested.iter() { - match value { - NestedMeta::Meta(Meta::NameValue(MetaNameValue { - path, - lit: Lit::Int(val), - .. - })) if path.is_ident("oid") => { - try_set!(postgres_oid, val.base10_parse()?, value); - } - u => fail!(u, "unexpected value"), - } - } - } - - u => fail!(u, "unexpected attribute"), - }, - u => fail!(u, "unexpected attribute"), - } - } - } - Meta::List(list) if list.path.is_ident("repr") => { - if list.nested.len() != 1 { - fail!(&list.nested, "expected one value") - } - match list.nested.first().unwrap() { - NestedMeta::Meta(Meta::Path(p)) if p.get_ident().is_some() => { - try_set!(repr, p.get_ident().unwrap().clone(), list); - } - u => fail!(u, "unexpected value"), - } - } - _ => {} - } - } - - Ok(SqlxAttributes { - transparent: transparent.unwrap_or(false), - postgres_oid, - repr, - rename, - }) -} - -fn check_transparent_attributes(input: &DeriveInput, field: &Field) -> syn::Result<()> { - let attributes = parse_attributes(&input.attrs)?; - assert_attribute!( - attributes.transparent, - "expected #[sqlx(transparent)]", - input - ); - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_none(), - "unexpected #[sqlx(postgres(oid = ..))]", - input - ); - assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", - field - ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); - let attributes = parse_attributes(&field.attrs)?; - assert_attribute!( - !attributes.transparent, - "unexpected #[sqlx(transparent)]", - field - ); - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_none(), - "unexpected #[sqlx(postgres(oid = ..))]", - field - ); - assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", - field - ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field); - Ok(()) -} - -fn check_enum_attributes<'a>( - input: &'a DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let attributes = parse_attributes(&input.attrs)?; - assert_attribute!( - !attributes.transparent, - "unexpected #[sqlx(transparent)]", - input - ); - assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", - input - ); - - for variant in variants { - let attributes = parse_attributes(&variant.attrs)?; - assert_attribute!( - !attributes.transparent, - "unexpected #[sqlx(transparent)]", - variant - ); - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_none(), - "unexpected #[sqlx(postgres(oid = ..))]", - variant - ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", variant); - } - - Ok(attributes) -} - -fn check_weak_enum_attributes( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let attributes = check_enum_attributes(input, variants)?; - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_none(), - "unexpected #[sqlx(postgres(oid = ..))]", - input - ); - assert_attribute!(attributes.repr.is_some(), "expected #[repr(..)]", input); - for variant in variants { - let attributes = parse_attributes(&variant.attrs)?; - assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", - variant - ); - } - Ok(attributes.repr.unwrap()) -} - -fn check_strong_enum_attributes( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let attributes = check_enum_attributes(input, variants)?; - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_some(), - "expected #[sqlx(postgres(oid = ..))]", - input - ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); - Ok(attributes) -} - -fn check_struct_attributes<'a>( - input: &'a DeriveInput, - fields: &Punctuated, -) -> syn::Result { - let attributes = parse_attributes(&input.attrs)?; - assert_attribute!( - !attributes.transparent, - "unexpected #[sqlx(transparent)]", - input - ); - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_some(), - "expected #[sqlx(postgres(oid = ..))]", - input - ); - assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", - input - ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); - - for field in fields { - let attributes = parse_attributes(&field.attrs)?; - assert_attribute!( - !attributes.transparent, - "unexpected #[sqlx(transparent)]", - field - ); - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_none(), - "unexpected #[sqlx(postgres(oid = ..))]", - field - ); - assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", - field - ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field); - } - - Ok(attributes) -} - -pub(crate) fn expand_derive_encode(input: &DeriveInput) -> syn::Result { - let args = parse_attributes(&input.attrs)?; - - match &input.data { - Data::Struct(DataStruct { - fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), - .. - }) if unnamed.len() == 1 => { - expand_derive_encode_transparent(&input, unnamed.first().unwrap()) - } - Data::Enum(DataEnum { variants, .. }) => match args.repr { - Some(_) => expand_derive_encode_weak_enum(input, variants), - None => expand_derive_encode_strong_enum(input, variants), - }, - Data::Struct(DataStruct { - fields: Fields::Named(FieldsNamed { named, .. }), - .. - }) => expand_derive_encode_struct(input, named), - _ => Err(syn::Error::new_spanned( - input, - "expected a tuple struct with a single field", - )), - } -} - -fn expand_derive_encode_transparent( - input: &DeriveInput, - field: &Field, -) -> syn::Result { - check_transparent_attributes(input, field)?; - - let ident = &input.ident; - let ty = &field.ty; - - // extract type generics - let generics = &input.generics; - let (_, ty_generics, _) = generics.split_for_impl(); - - // add db type for impl generics & where clause - let mut generics = generics.clone(); - generics.params.insert(0, parse_quote!(DB: sqlx::Database)); - generics - .make_where_clause() - .predicates - .push(parse_quote!(#ty: sqlx::encode::Encode)); - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - Ok(quote!( - impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { - fn encode(&self, buf: &mut std::vec::Vec) { - sqlx::encode::Encode::encode(&self.0, buf) - } - fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { - sqlx::encode::Encode::encode_nullable(&self.0, buf) - } - fn size_hint(&self) -> usize { - sqlx::encode::Encode::size_hint(&self.0) - } - } - )) -} - -fn expand_derive_encode_weak_enum( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let repr = check_weak_enum_attributes(input, &variants)?; - - let ident = &input.ident; - - Ok(quote!( - impl sqlx::encode::Encode for #ident where #repr: sqlx::encode::Encode { - fn encode(&self, buf: &mut std::vec::Vec) { - sqlx::encode::Encode::encode(&(*self as #repr), buf) - } - fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { - sqlx::encode::Encode::encode_nullable(&(*self as #repr), buf) - } - fn size_hint(&self) -> usize { - sqlx::encode::Encode::size_hint(&(*self as #repr)) - } - } - )) -} - -fn expand_derive_encode_strong_enum( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - check_strong_enum_attributes(input, &variants)?; - - let ident = &input.ident; - - let mut value_arms = Vec::new(); - for v in variants { - let id = &v.ident; - let attributes = parse_attributes(&v.attrs)?; - if let Some(rename) = attributes.rename { - value_arms.push(quote!(#ident :: #id => #rename,)); - } else { - let name = id.to_string(); - value_arms.push(quote!(#ident :: #id => #name,)); - } - } - - Ok(quote!( - impl sqlx::encode::Encode for #ident where str: sqlx::encode::Encode { - fn encode(&self, buf: &mut std::vec::Vec) { - let val = match self { - #(#value_arms)* - }; - >::encode(val, buf) - } - fn size_hint(&self) -> usize { - let val = match self { - #(#value_arms)* - }; - >::size_hint(val) - } - } - )) -} - -fn expand_derive_encode_struct( - input: &DeriveInput, - fields: &Punctuated, -) -> syn::Result { - check_struct_attributes(input, &fields)?; - - let mut tts = proc_macro2::TokenStream::new(); - - if cfg!(feature = "postgres") { - let ident = &input.ident; - - let column_count = fields.len(); - - // extract type generics - let generics = &input.generics; - let (_, ty_generics, _) = generics.split_for_impl(); - - // add db type for impl generics & where clause - let mut generics = generics.clone(); - let predicates = &mut generics.make_where_clause().predicates; - for field in fields { - let ty = &field.ty; - predicates.push(parse_quote!(#ty: sqlx::encode::Encode)); - predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); - } - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - let mut writes: Vec = Vec::new(); - for field in fields { - let id = &field.ident; - let ty = &field.ty; - writes.push(parse_quote!({ - // write oid - let info = >::type_info(); - buf.extend(&info.oid().to_be_bytes()); - - // write zeros for length - buf.extend(&[0; 4]); - - let start = buf.len(); - sqlx::encode::Encode::::encode(&self. #id, buf); - let end = buf.len(); - let size = end - start; - - // replaces zeros with actual length - buf[start-4..start].copy_from_slice(&(size as u32).to_be_bytes()); - })); - } - - let mut sizes: Vec = Vec::new(); - for field in fields { - let id = &field.ident; - let ty = &field.ty; - sizes.push( - parse_quote!(<#ty as sqlx::encode::Encode>::size_hint(&self. #id)), - ); - } - - tts.extend(quote!( - impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { - fn encode(&self, buf: &mut std::vec::Vec) { - buf.extend(&(#column_count as u32).to_be_bytes()); - #(#writes)* - } - fn size_hint(&self) -> usize { - 4 + #column_count * (4 + 4) + #(#sizes)+* - } - } - )); - } - - Ok(tts) -} - -pub(crate) fn expand_derive_decode(input: &DeriveInput) -> syn::Result { - let attrs = parse_attributes(&input.attrs)?; - match &input.data { - Data::Struct(DataStruct { - fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), - .. - }) if unnamed.len() == 1 => { - expand_derive_decode_transparent(input, unnamed.first().unwrap()) - } - Data::Enum(DataEnum { variants, .. }) => match attrs.repr { - Some(_) => expand_derive_decode_weak_enum(input, variants), - None => expand_derive_decode_strong_enum(input, variants), - }, - Data::Struct(DataStruct { - fields: Fields::Named(FieldsNamed { named, .. }), - .. - }) => expand_derive_decode_struct(input, named), - _ => Err(syn::Error::new_spanned( - input, - "expected a tuple struct with a single field", - )), - } -} - -fn expand_derive_decode_transparent( - input: &DeriveInput, - field: &Field, -) -> syn::Result { - check_transparent_attributes(input, field)?; - - let ident = &input.ident; - let ty = &field.ty; - - // extract type generics - let generics = &input.generics; - let (_, ty_generics, _) = generics.split_for_impl(); - - // add db type for impl generics & where clause - let mut generics = generics.clone(); - generics.params.insert(0, parse_quote!(DB: sqlx::Database)); - generics - .make_where_clause() - .predicates - .push(parse_quote!(#ty: sqlx::decode::Decode)); - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - Ok(quote!( - impl #impl_generics sqlx::decode::Decode for #ident #ty_generics #where_clause { - fn decode(raw: &[u8]) -> std::result::Result { - <#ty as sqlx::decode::Decode>::decode(raw).map(Self) - } - fn decode_null() -> std::result::Result { - <#ty as sqlx::decode::Decode>::decode_null().map(Self) - } - fn decode_nullable(raw: std::option::Option<&[u8]>) -> std::result::Result { - <#ty as sqlx::decode::Decode>::decode_nullable(raw).map(Self) - } - } - )) -} - -fn expand_derive_decode_weak_enum( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let repr = check_weak_enum_attributes(input, &variants)?; - - let ident = &input.ident; - let arms = variants - .iter() - .map(|v| { - let id = &v.ident; - parse_quote!(_ if (#ident :: #id as #repr) == val => Ok(#ident :: #id),) - }) - .collect::>(); - - Ok(quote!( - impl sqlx::decode::Decode for #ident where #repr: sqlx::decode::Decode { - fn decode(raw: &[u8]) -> std::result::Result { - let val = <#repr as sqlx::decode::Decode>::decode(raw)?; - match val { - #(#arms)* - _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) - } - } - } - )) -} - -fn expand_derive_decode_strong_enum( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - check_strong_enum_attributes(input, &variants)?; - - let ident = &input.ident; - - let mut value_arms = Vec::new(); - for v in variants { - let id = &v.ident; - let attributes = parse_attributes(&v.attrs)?; - if let Some(rename) = attributes.rename { - value_arms.push(quote!(#rename => Ok(#ident :: #id),)); - } else { - let name = id.to_string(); - value_arms.push(quote!(#name => Ok(#ident :: #id),)); - } - } - - // TODO: prevent heap allocation - Ok(quote!( - impl sqlx::decode::Decode for #ident where std::string::String: sqlx::decode::Decode { - fn decode(buf: &[u8]) -> std::result::Result { - let val = >::decode(buf)?; - match val.as_str() { - #(#value_arms)* - _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) - } - } - } - )) -} - -fn expand_derive_decode_struct( - input: &DeriveInput, - fields: &Punctuated, -) -> syn::Result { - check_struct_attributes(input, fields)?; - - let mut tts = proc_macro2::TokenStream::new(); - - if cfg!(feature = "postgres") { - let ident = &input.ident; - - let column_count = fields.len(); - - // extract type generics - let generics = &input.generics; - let (_, ty_generics, _) = generics.split_for_impl(); - - // add db type for impl generics & where clause - let mut generics = generics.clone(); - let predicates = &mut generics.make_where_clause().predicates; - for field in fields { - let ty = &field.ty; - predicates.push(parse_quote!(#ty: sqlx::decode::Decode)); - predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); - } - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - let mut reads: Vec> = Vec::new(); - let mut names: Vec = Vec::new(); - for field in fields { - let id = &field.ident; - names.push(id.clone().unwrap()); - let ty = &field.ty; - reads.push(parse_quote!( - if buf.len() < 8 { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); - } - - let oid = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[0..4]).unwrap()); - if oid != >::type_info().oid() { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid oid"))); - } - - let len = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[4..8]).unwrap()) as usize; - - if buf.len() < 8 + len { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); - } - - let raw = &buf[8..8+len]; - let #id = <#ty as sqlx::decode::Decode>::decode(raw)?; - - let buf = &buf[8+len..]; - )); - } - let reads = reads.into_iter().flatten(); - - tts.extend(quote!( - impl #impl_generics sqlx::decode::Decode for #ident#ty_generics #where_clause { - fn decode(buf: &[u8]) -> std::result::Result { - if buf.len() < 4 { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); - } - - let column_count = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[..4]).unwrap()) as usize; - if column_count != #column_count { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid column count"))); - } - let buf = &buf[4..]; - - #(#reads)* - - if !buf.is_empty() { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new(format!("Too much data sent ({} bytes left)", buf.len())))); - } - - Ok(#ident { - #(#names),* - }) - } - } - )) - } - Ok(tts) -} - -pub(crate) fn expand_derive_has_sql_type( - input: &DeriveInput, -) -> syn::Result { - let attrs = parse_attributes(&input.attrs)?; - match &input.data { - Data::Struct(DataStruct { - fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), - .. - }) if unnamed.len() == 1 => { - expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap()) - } - Data::Enum(DataEnum { variants, .. }) => match attrs.repr { - Some(_) => expand_derive_has_sql_type_weak_enum(input, variants), - None => expand_derive_has_sql_type_strong_enum(input, variants), - }, - Data::Struct(DataStruct { - fields: Fields::Named(FieldsNamed { named, .. }), - .. - }) => expand_derive_has_sql_type_struct(input, named), - _ => Err(syn::Error::new_spanned( - input, - "expected a tuple struct with a single field", - )), - } -} - -fn expand_derive_has_sql_type_transparent( - input: &DeriveInput, - field: &Field, -) -> syn::Result { - check_transparent_attributes(input, field)?; - - let ident = &input.ident; - let ty = &field.ty; - - // extract type generics - let generics = &input.generics; - let (impl_generics, ty_generics, _) = generics.split_for_impl(); - - // add db type for clause - let mut generics = generics.clone(); - generics - .make_where_clause() - .predicates - .push(parse_quote!(Self: sqlx::types::HasSqlType<#ty>)); - let (_, _, where_clause) = generics.split_for_impl(); - - let mut tts = proc_macro2::TokenStream::new(); - - if cfg!(feature = "mysql") { - tts.extend(quote!( - impl #impl_generics sqlx::types::HasSqlType< #ident #ty_generics > for sqlx::MySql #where_clause { - fn type_info() -> Self::TypeInfo { - >::type_info() - } - } - )); - } - - if cfg!(feature = "postgres") { - tts.extend(quote!( - impl #impl_generics sqlx::types::HasSqlType< #ident #ty_generics > for sqlx::Postgres #where_clause { - fn type_info() -> Self::TypeInfo { - >::type_info() - } - } - )); - } - - Ok(tts) -} - -fn expand_derive_has_sql_type_weak_enum( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let repr = check_weak_enum_attributes(input, variants)?; - - let ident = &input.ident; - let mut tts = proc_macro2::TokenStream::new(); - - if cfg!(feature = "mysql") { - tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::MySql where Self: sqlx::types::HasSqlType< #repr > { - fn type_info() -> Self::TypeInfo { - >::type_info() - } - } - )); - } - - if cfg!(feature = "postgres") { - tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres where Self: sqlx::types::HasSqlType< #repr > { - fn type_info() -> Self::TypeInfo { - >::type_info() - } - } - )); - } - - Ok(tts) -} - -fn expand_derive_has_sql_type_strong_enum( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let attributes = check_strong_enum_attributes(input, variants)?; - - let ident = &input.ident; - let mut tts = proc_macro2::TokenStream::new(); - - if cfg!(feature = "mysql") { - tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::MySql { - fn type_info() -> Self::TypeInfo { - sqlx::mysql::MySqlTypeInfo::r#enum() - } - } - )); - } - - if cfg!(feature = "postgres") { - let oid = attributes.postgres_oid.unwrap(); - tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres { - fn type_info() -> Self::TypeInfo { - sqlx::postgres::PgTypeInfo::with_oid(#oid) - } - } - )); - } - - Ok(tts) -} - -fn expand_derive_has_sql_type_struct( - input: &DeriveInput, - fields: &Punctuated, -) -> syn::Result { - let attributes = check_struct_attributes(input, fields)?; - - let ident = &input.ident; - let mut tts = proc_macro2::TokenStream::new(); - - if cfg!(feature = "postgres") { - let oid = attributes.postgres_oid.unwrap(); - tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres { - fn type_info() -> Self::TypeInfo { - sqlx::postgres::PgTypeInfo::with_oid(#oid) - } - } - )); - } - - Ok(tts) -} - -pub(crate) fn expand_derive_type(input: &DeriveInput) -> syn::Result { - let encode_tts = expand_derive_encode(input)?; - let decode_tts = expand_derive_decode(input)?; - let has_sql_type_tts = expand_derive_has_sql_type(input)?; - - let combined = proc_macro2::TokenStream::from_iter( - encode_tts - .into_iter() - .chain(decode_tts) - .chain(has_sql_type_tts), - ); - Ok(combined) -} diff --git a/sqlx-macros/src/derives/attributes.rs b/sqlx-macros/src/derives/attributes.rs new file mode 100644 index 00000000..72df6991 --- /dev/null +++ b/sqlx-macros/src/derives/attributes.rs @@ -0,0 +1,261 @@ +use proc_macro2::Ident; +use syn::punctuated::Punctuated; +use syn::token::Comma; +use syn::{Attribute, DeriveInput, Field, Lit, Meta, MetaNameValue, NestedMeta, Variant}; + +macro_rules! assert_attribute { + ($e:expr, $err:expr, $input:expr) => { + if !$e { + return Err(syn::Error::new_spanned($input, $err)); + } + }; +} + +pub struct SqlxAttributes { + pub transparent: bool, + pub postgres_oid: Option, + pub repr: Option, + pub rename: Option, +} + +pub fn parse_attributes(input: &[Attribute]) -> syn::Result { + let mut transparent = None; + let mut postgres_oid = None; + let mut repr = None; + let mut rename = None; + + macro_rules! fail { + ($t:expr, $m:expr) => { + return Err(syn::Error::new_spanned($t, $m)); + }; + } + + macro_rules! try_set { + ($i:ident, $v:expr, $t:expr) => { + match $i { + None => $i = Some($v), + Some(_) => fail!($t, "duplicate attribute"), + } + }; + } + + for attr in input { + let meta = attr + .parse_meta() + .map_err(|e| syn::Error::new_spanned(attr, e))?; + match meta { + Meta::List(list) if list.path.is_ident("sqlx") => { + for value in list.nested.iter() { + match value { + NestedMeta::Meta(meta) => match meta { + Meta::Path(p) if p.is_ident("transparent") => { + try_set!(transparent, true, value) + } + Meta::NameValue(MetaNameValue { + path, + lit: Lit::Str(val), + .. + }) if path.is_ident("rename") => try_set!(rename, val.value(), value), + Meta::List(list) if list.path.is_ident("postgres") => { + for value in list.nested.iter() { + match value { + NestedMeta::Meta(Meta::NameValue(MetaNameValue { + path, + lit: Lit::Int(val), + .. + })) if path.is_ident("oid") => { + try_set!(postgres_oid, val.base10_parse()?, value); + } + u => fail!(u, "unexpected value"), + } + } + } + + u => fail!(u, "unexpected attribute"), + }, + u => fail!(u, "unexpected attribute"), + } + } + } + Meta::List(list) if list.path.is_ident("repr") => { + if list.nested.len() != 1 { + fail!(&list.nested, "expected one value") + } + match list.nested.first().unwrap() { + NestedMeta::Meta(Meta::Path(p)) if p.get_ident().is_some() => { + try_set!(repr, p.get_ident().unwrap().clone(), list); + } + u => fail!(u, "unexpected value"), + } + } + _ => {} + } + } + + Ok(SqlxAttributes { + transparent: transparent.unwrap_or(false), + postgres_oid, + repr, + rename, + }) +} + +pub fn check_transparent_attributes(input: &DeriveInput, field: &Field) -> syn::Result<()> { + let attributes = parse_attributes(&input.attrs)?; + assert_attribute!( + attributes.transparent, + "expected #[sqlx(transparent)]", + input + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + field + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); + let attributes = parse_attributes(&field.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + field + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + field + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + field + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field); + Ok(()) +} + +pub fn check_enum_attributes<'a>( + input: &'a DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = parse_attributes(&input.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + input + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + input + ); + + for variant in variants { + let attributes = parse_attributes(&variant.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + variant + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + variant + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", variant); + } + + Ok(attributes) +} + +pub fn check_weak_enum_attributes( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = check_enum_attributes(input, variants)?; + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!(attributes.repr.is_some(), "expected #[repr(..)]", input); + for variant in variants { + let attributes = parse_attributes(&variant.attrs)?; + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + variant + ); + } + Ok(attributes.repr.unwrap()) +} + +pub fn check_strong_enum_attributes( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = check_enum_attributes(input, variants)?; + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_some(), + "expected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); + Ok(attributes) +} + +pub fn check_struct_attributes<'a>( + input: &'a DeriveInput, + fields: &Punctuated, +) -> syn::Result { + let attributes = parse_attributes(&input.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + input + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_some(), + "expected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + input + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); + + for field in fields { + let attributes = parse_attributes(&field.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + field + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + field + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + field + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field); + } + + Ok(attributes) +} diff --git a/sqlx-macros/src/derives/decode.rs b/sqlx-macros/src/derives/decode.rs new file mode 100644 index 00000000..2aa779f8 --- /dev/null +++ b/sqlx-macros/src/derives/decode.rs @@ -0,0 +1,221 @@ +use super::attributes::{ + check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, + check_weak_enum_attributes, parse_attributes, +}; +use proc_macro2::Ident; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::token::Comma; +use syn::{ + parse_quote, Arm, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed, + FieldsUnnamed, Stmt, Variant, +}; + +pub fn expand_derive_decode(input: &DeriveInput) -> syn::Result { + let attrs = parse_attributes(&input.attrs)?; + match &input.data { + Data::Struct(DataStruct { + fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), + .. + }) if unnamed.len() == 1 => { + expand_derive_decode_transparent(input, unnamed.first().unwrap()) + } + Data::Enum(DataEnum { variants, .. }) => match attrs.repr { + Some(_) => expand_derive_decode_weak_enum(input, variants), + None => expand_derive_decode_strong_enum(input, variants), + }, + Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { named, .. }), + .. + }) => expand_derive_decode_struct(input, named), + _ => Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )), + } +} + +fn expand_derive_decode_transparent( + input: &DeriveInput, + field: &Field, +) -> syn::Result { + check_transparent_attributes(input, field)?; + + let ident = &input.ident; + let ty = &field.ty; + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::decode::Decode)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + Ok(quote!( + impl #impl_generics sqlx::decode::Decode for #ident #ty_generics #where_clause { + fn decode(raw: &[u8]) -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode(raw).map(Self) + } + fn decode_null() -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode_null().map(Self) + } + fn decode_nullable(raw: std::option::Option<&[u8]>) -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode_nullable(raw).map(Self) + } + } + )) +} + +fn expand_derive_decode_weak_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let repr = check_weak_enum_attributes(input, &variants)?; + + let ident = &input.ident; + let arms = variants + .iter() + .map(|v| { + let id = &v.ident; + parse_quote!(_ if (#ident :: #id as #repr) == val => Ok(#ident :: #id),) + }) + .collect::>(); + + Ok(quote!( + impl sqlx::decode::Decode for #ident where #repr: sqlx::decode::Decode { + fn decode(raw: &[u8]) -> std::result::Result { + let val = <#repr as sqlx::decode::Decode>::decode(raw)?; + match val { + #(#arms)* + _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) + } + } + } + )) +} + +fn expand_derive_decode_strong_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + check_strong_enum_attributes(input, &variants)?; + + let ident = &input.ident; + + let mut value_arms = Vec::new(); + for v in variants { + let id = &v.ident; + let attributes = parse_attributes(&v.attrs)?; + if let Some(rename) = attributes.rename { + value_arms.push(quote!(#rename => Ok(#ident :: #id),)); + } else { + let name = id.to_string(); + value_arms.push(quote!(#name => Ok(#ident :: #id),)); + } + } + + // TODO: prevent heap allocation + Ok(quote!( + impl sqlx::decode::Decode for #ident where std::string::String: sqlx::decode::Decode { + fn decode(buf: &[u8]) -> std::result::Result { + let val = >::decode(buf)?; + match val.as_str() { + #(#value_arms)* + _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) + } + } + } + )) +} + +fn expand_derive_decode_struct( + input: &DeriveInput, + fields: &Punctuated, +) -> syn::Result { + check_struct_attributes(input, fields)?; + + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "postgres") { + let ident = &input.ident; + + let column_count = fields.len(); + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + let predicates = &mut generics.make_where_clause().predicates; + for field in fields { + let ty = &field.ty; + predicates.push(parse_quote!(#ty: sqlx::decode::Decode)); + predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); + } + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + let mut reads: Vec> = Vec::new(); + let mut names: Vec = Vec::new(); + for field in fields { + let id = &field.ident; + names.push(id.clone().unwrap()); + let ty = &field.ty; + reads.push(parse_quote!( + if buf.len() < 8 { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); + } + + let oid = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[0..4]).unwrap()); + if oid != >::type_info().oid() { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid oid"))); + } + + let len = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[4..8]).unwrap()) as usize; + + if buf.len() < 8 + len { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); + } + + let raw = &buf[8..8+len]; + let #id = <#ty as sqlx::decode::Decode>::decode(raw)?; + + let buf = &buf[8+len..]; + )); + } + let reads = reads.into_iter().flatten(); + + tts.extend(quote!( + impl #impl_generics sqlx::decode::Decode for #ident#ty_generics #where_clause { + fn decode(buf: &[u8]) -> std::result::Result { + if buf.len() < 4 { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); + } + + let column_count = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[..4]).unwrap()) as usize; + if column_count != #column_count { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid column count"))); + } + let buf = &buf[4..]; + + #(#reads)* + + if !buf.is_empty() { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new(format!("Too much data sent ({} bytes left)", buf.len())))); + } + + Ok(#ident { + #(#names),* + }) + } + } + )) + } + Ok(tts) +} diff --git a/sqlx-macros/src/derives/encode.rs b/sqlx-macros/src/derives/encode.rs new file mode 100644 index 00000000..03c91d00 --- /dev/null +++ b/sqlx-macros/src/derives/encode.rs @@ -0,0 +1,208 @@ +use super::attributes::{ + check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, + check_weak_enum_attributes, parse_attributes, +}; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::token::Comma; +use syn::{ + parse_quote, Block, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed, + FieldsUnnamed, Variant, +}; + +pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result { + let args = parse_attributes(&input.attrs)?; + + match &input.data { + Data::Struct(DataStruct { + fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), + .. + }) if unnamed.len() == 1 => { + expand_derive_encode_transparent(&input, unnamed.first().unwrap()) + } + Data::Enum(DataEnum { variants, .. }) => match args.repr { + Some(_) => expand_derive_encode_weak_enum(input, variants), + None => expand_derive_encode_strong_enum(input, variants), + }, + Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { named, .. }), + .. + }) => expand_derive_encode_struct(input, named), + _ => Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )), + } +} + +fn expand_derive_encode_transparent( + input: &DeriveInput, + field: &Field, +) -> syn::Result { + check_transparent_attributes(input, field)?; + + let ident = &input.ident; + let ty = &field.ty; + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::encode::Encode)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + Ok(quote!( + impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { + fn encode(&self, buf: &mut std::vec::Vec) { + sqlx::encode::Encode::encode(&self.0, buf) + } + fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { + sqlx::encode::Encode::encode_nullable(&self.0, buf) + } + fn size_hint(&self) -> usize { + sqlx::encode::Encode::size_hint(&self.0) + } + } + )) +} + +fn expand_derive_encode_weak_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let repr = check_weak_enum_attributes(input, &variants)?; + + let ident = &input.ident; + + Ok(quote!( + impl sqlx::encode::Encode for #ident where #repr: sqlx::encode::Encode { + fn encode(&self, buf: &mut std::vec::Vec) { + sqlx::encode::Encode::encode(&(*self as #repr), buf) + } + fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { + sqlx::encode::Encode::encode_nullable(&(*self as #repr), buf) + } + fn size_hint(&self) -> usize { + sqlx::encode::Encode::size_hint(&(*self as #repr)) + } + } + )) +} + +fn expand_derive_encode_strong_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + check_strong_enum_attributes(input, &variants)?; + + let ident = &input.ident; + + let mut value_arms = Vec::new(); + for v in variants { + let id = &v.ident; + let attributes = parse_attributes(&v.attrs)?; + if let Some(rename) = attributes.rename { + value_arms.push(quote!(#ident :: #id => #rename,)); + } else { + let name = id.to_string(); + value_arms.push(quote!(#ident :: #id => #name,)); + } + } + + Ok(quote!( + impl sqlx::encode::Encode for #ident where str: sqlx::encode::Encode { + fn encode(&self, buf: &mut std::vec::Vec) { + let val = match self { + #(#value_arms)* + }; + >::encode(val, buf) + } + fn size_hint(&self) -> usize { + let val = match self { + #(#value_arms)* + }; + >::size_hint(val) + } + } + )) +} + +fn expand_derive_encode_struct( + input: &DeriveInput, + fields: &Punctuated, +) -> syn::Result { + check_struct_attributes(input, &fields)?; + + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "postgres") { + let ident = &input.ident; + + let column_count = fields.len(); + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + let predicates = &mut generics.make_where_clause().predicates; + for field in fields { + let ty = &field.ty; + predicates.push(parse_quote!(#ty: sqlx::encode::Encode)); + predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); + } + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + let mut writes: Vec = Vec::new(); + for field in fields { + let id = &field.ident; + let ty = &field.ty; + writes.push(parse_quote!({ + // write oid + let info = >::type_info(); + buf.extend(&info.oid().to_be_bytes()); + + // write zeros for length + buf.extend(&[0; 4]); + + let start = buf.len(); + sqlx::encode::Encode::::encode(&self. #id, buf); + let end = buf.len(); + let size = end - start; + + // replaces zeros with actual length + buf[start-4..start].copy_from_slice(&(size as u32).to_be_bytes()); + })); + } + + let mut sizes: Vec = Vec::new(); + for field in fields { + let id = &field.ident; + let ty = &field.ty; + sizes.push( + parse_quote!(<#ty as sqlx::encode::Encode>::size_hint(&self. #id)), + ); + } + + tts.extend(quote!( + impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { + fn encode(&self, buf: &mut std::vec::Vec) { + buf.extend(&(#column_count as u32).to_be_bytes()); + #(#writes)* + } + fn size_hint(&self) -> usize { + 4 + #column_count * (4 + 4) + #(#sizes)+* + } + } + )); + } + + Ok(tts) +} diff --git a/sqlx-macros/src/derives/has_sql_type.rs b/sqlx-macros/src/derives/has_sql_type.rs new file mode 100644 index 00000000..64da9ae1 --- /dev/null +++ b/sqlx-macros/src/derives/has_sql_type.rs @@ -0,0 +1,169 @@ +use super::attributes::{ + check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, + check_weak_enum_attributes, parse_attributes, +}; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::token::Comma; +use syn::{ + parse_quote, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed, + FieldsUnnamed, Variant, +}; + +pub fn expand_derive_has_sql_type(input: &DeriveInput) -> syn::Result { + let attrs = parse_attributes(&input.attrs)?; + match &input.data { + Data::Struct(DataStruct { + fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), + .. + }) if unnamed.len() == 1 => { + expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap()) + } + Data::Enum(DataEnum { variants, .. }) => match attrs.repr { + Some(_) => expand_derive_has_sql_type_weak_enum(input, variants), + None => expand_derive_has_sql_type_strong_enum(input, variants), + }, + Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { named, .. }), + .. + }) => expand_derive_has_sql_type_struct(input, named), + _ => Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )), + } +} + +fn expand_derive_has_sql_type_transparent( + input: &DeriveInput, + field: &Field, +) -> syn::Result { + check_transparent_attributes(input, field)?; + + let ident = &input.ident; + let ty = &field.ty; + + // extract type generics + let generics = &input.generics; + let (impl_generics, ty_generics, _) = generics.split_for_impl(); + + // add db type for clause + let mut generics = generics.clone(); + generics + .make_where_clause() + .predicates + .push(parse_quote!(Self: sqlx::types::HasSqlType<#ty>)); + let (_, _, where_clause) = generics.split_for_impl(); + + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl #impl_generics sqlx::types::HasSqlType< #ident #ty_generics > for sqlx::MySql #where_clause { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + if cfg!(feature = "postgres") { + tts.extend(quote!( + impl #impl_generics sqlx::types::HasSqlType< #ident #ty_generics > for sqlx::Postgres #where_clause { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + Ok(tts) +} + +fn expand_derive_has_sql_type_weak_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let repr = check_weak_enum_attributes(input, variants)?; + + let ident = &input.ident; + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::MySql where Self: sqlx::types::HasSqlType< #repr > { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + if cfg!(feature = "postgres") { + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres where Self: sqlx::types::HasSqlType< #repr > { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + Ok(tts) +} + +fn expand_derive_has_sql_type_strong_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = check_strong_enum_attributes(input, variants)?; + + let ident = &input.ident; + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::MySql { + fn type_info() -> Self::TypeInfo { + sqlx::mysql::MySqlTypeInfo::r#enum() + } + } + )); + } + + if cfg!(feature = "postgres") { + let oid = attributes.postgres_oid.unwrap(); + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres { + fn type_info() -> Self::TypeInfo { + sqlx::postgres::PgTypeInfo::with_oid(#oid) + } + } + )); + } + + Ok(tts) +} + +fn expand_derive_has_sql_type_struct( + input: &DeriveInput, + fields: &Punctuated, +) -> syn::Result { + let attributes = check_struct_attributes(input, fields)?; + + let ident = &input.ident; + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "postgres") { + let oid = attributes.postgres_oid.unwrap(); + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres { + fn type_info() -> Self::TypeInfo { + sqlx::postgres::PgTypeInfo::with_oid(#oid) + } + } + )); + } + + Ok(tts) +} diff --git a/sqlx-macros/src/derives/mod.rs b/sqlx-macros/src/derives/mod.rs new file mode 100644 index 00000000..28d9eee7 --- /dev/null +++ b/sqlx-macros/src/derives/mod.rs @@ -0,0 +1,25 @@ +mod attributes; +mod decode; +mod encode; +mod has_sql_type; + +pub(crate) use decode::expand_derive_decode; +pub(crate) use encode::expand_derive_encode; +pub(crate) use has_sql_type::expand_derive_has_sql_type; + +use std::iter::FromIterator; +use syn::DeriveInput; + +pub(crate) fn expand_derive_type(input: &DeriveInput) -> syn::Result { + let encode_tts = expand_derive_encode(input)?; + let decode_tts = expand_derive_decode(input)?; + let has_sql_type_tts = expand_derive_has_sql_type(input)?; + + let combined = proc_macro2::TokenStream::from_iter( + encode_tts + .into_iter() + .chain(decode_tts) + .chain(has_sql_type_tts), + ); + Ok(combined) +} From 4cd179d42b0bbb72569f44bbeccd097978a737c6 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 10 Feb 2020 11:43:37 +0100 Subject: [PATCH 06/25] move decode_struct_field and encode_struct_field to sqlx-core --- sqlx-core/src/postgres/mod.rs | 2 +- sqlx-core/src/postgres/types/mod.rs | 1 + sqlx-core/src/postgres/types/struct.rs | 59 ++++++++++++++++++++++++++ sqlx-macros/src/derives/decode.rs | 27 ++---------- sqlx-macros/src/derives/encode.rs | 26 +++--------- 5 files changed, 71 insertions(+), 44 deletions(-) create mode 100644 sqlx-core/src/postgres/types/struct.rs diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs index f1dd5997..d5f80737 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -29,4 +29,4 @@ pub type PgPool = crate::pool::Pool; make_query_as!(PgQueryAs, Postgres, PgRow); impl_map_row_for_row!(Postgres, PgRow); impl_column_index_for_row!(Postgres); -impl_from_row_for_tuples!(Postgres, PgRow); +impl_from_row_for_tuples!(Postgres, PgRow); \ No newline at end of file diff --git a/sqlx-core/src/postgres/types/mod.rs b/sqlx-core/src/postgres/types/mod.rs index ab676e9a..91d1acc3 100644 --- a/sqlx-core/src/postgres/types/mod.rs +++ b/sqlx-core/src/postgres/types/mod.rs @@ -12,6 +12,7 @@ mod bytes; mod float; mod int; mod str; +pub mod r#struct; #[cfg(feature = "chrono")] mod chrono; diff --git a/sqlx-core/src/postgres/types/struct.rs b/sqlx-core/src/postgres/types/struct.rs new file mode 100644 index 00000000..ee9a080e --- /dev/null +++ b/sqlx-core/src/postgres/types/struct.rs @@ -0,0 +1,59 @@ +use crate::decode::{Decode, DecodeError}; +use crate::encode::Encode; +use crate::postgres::protocol::TypeId; +use crate::postgres::types::PgTypeInfo; +use crate::types::HasSqlType; +use crate::Postgres; +use std::convert::TryInto; + +/// read a struct field and advance the buffer +pub fn decode_struct_field>(buf: &mut &[u8]) -> Result +where + Postgres: HasSqlType, +{ + if buf.len() < 8 { + return Err(DecodeError::Message(std::boxed::Box::new( + "Not enough data sent", + ))); + } + + let oid = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[0..4]).unwrap()); + if oid != >::type_info().oid() { + return Err(DecodeError::Message(std::boxed::Box::new("Invalid oid"))); + } + + let len = u32::from_be_bytes(buf[4..8].try_into().unwrap()) as usize; + + if buf.len() < 8 + len { + return Err(DecodeError::Message(std::boxed::Box::new( + "Not enough data sent", + ))); + } + + let raw = &buf[8..8 + len]; + let value = T::decode(raw)?; + + *buf = &buf[8 + len..]; + + Ok(value) +} + +pub fn encode_struct_field>(buf: &mut Vec, value: &T) +where + Postgres: HasSqlType, +{ + // write oid + let info = >::type_info(); + buf.extend(&info.oid().to_be_bytes()); + + // write zeros for length + buf.extend(&[0; 4]); + + let start = buf.len(); + value.encode(buf); + let end = buf.len(); + let size = end - start; + + // replaces zeros with actual length + buf[start - 4..start].copy_from_slice(&(size as u32).to_be_bytes()); +} diff --git a/sqlx-macros/src/derives/decode.rs b/sqlx-macros/src/derives/decode.rs index 2aa779f8..dc9322db 100644 --- a/sqlx-macros/src/derives/decode.rs +++ b/sqlx-macros/src/derives/decode.rs @@ -161,35 +161,16 @@ fn expand_derive_decode_struct( } let (impl_generics, _, where_clause) = generics.split_for_impl(); - let mut reads: Vec> = Vec::new(); + let mut reads: Vec = Vec::new(); let mut names: Vec = Vec::new(); for field in fields { let id = &field.ident; names.push(id.clone().unwrap()); let ty = &field.ty; reads.push(parse_quote!( - if buf.len() < 8 { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); - } - - let oid = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[0..4]).unwrap()); - if oid != >::type_info().oid() { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid oid"))); - } - - let len = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[4..8]).unwrap()) as usize; - - if buf.len() < 8 + len { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); - } - - let raw = &buf[8..8+len]; - let #id = <#ty as sqlx::decode::Decode>::decode(raw)?; - - let buf = &buf[8+len..]; - )); + let #id = sqlx::postgres::decode_struct_field::<#ty>(&mut buf)?; + )); } - let reads = reads.into_iter().flatten(); tts.extend(quote!( impl #impl_generics sqlx::decode::Decode for #ident#ty_generics #where_clause { @@ -202,7 +183,7 @@ fn expand_derive_decode_struct( if column_count != #column_count { return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid column count"))); } - let buf = &buf[4..]; + let mut buf = &buf[4..]; #(#reads)* diff --git a/sqlx-macros/src/derives/encode.rs b/sqlx-macros/src/derives/encode.rs index 03c91d00..84aa53f1 100644 --- a/sqlx-macros/src/derives/encode.rs +++ b/sqlx-macros/src/derives/encode.rs @@ -6,8 +6,8 @@ use quote::quote; use syn::punctuated::Punctuated; use syn::token::Comma; use syn::{ - parse_quote, Block, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed, - FieldsUnnamed, Variant, + parse_quote, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed, + FieldsUnnamed, Stmt, Variant, }; pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result { @@ -160,26 +160,12 @@ fn expand_derive_encode_struct( } let (impl_generics, _, where_clause) = generics.split_for_impl(); - let mut writes: Vec = Vec::new(); + let mut writes: Vec = Vec::new(); for field in fields { let id = &field.ident; - let ty = &field.ty; - writes.push(parse_quote!({ - // write oid - let info = >::type_info(); - buf.extend(&info.oid().to_be_bytes()); - - // write zeros for length - buf.extend(&[0; 4]); - - let start = buf.len(); - sqlx::encode::Encode::::encode(&self. #id, buf); - let end = buf.len(); - let size = end - start; - - // replaces zeros with actual length - buf[start-4..start].copy_from_slice(&(size as u32).to_be_bytes()); - })); + writes.push(parse_quote!( + sqlx::postgres::encode_struct_field(buf, &self. #id); + )); } let mut sizes: Vec = Vec::new(); From 7185f1ff2533137c803d6cd0fb10adee9e868f55 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 10 Feb 2020 11:51:41 +0100 Subject: [PATCH 07/25] switch from vecs to iterator chains --- sqlx-macros/src/derives/decode.rs | 14 ++++++-------- sqlx-macros/src/derives/encode.rs | 20 +++++++++----------- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/sqlx-macros/src/derives/decode.rs b/sqlx-macros/src/derives/decode.rs index dc9322db..fa4604ef 100644 --- a/sqlx-macros/src/derives/decode.rs +++ b/sqlx-macros/src/derives/decode.rs @@ -2,7 +2,6 @@ use super::attributes::{ check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, check_weak_enum_attributes, parse_attributes, }; -use proc_macro2::Ident; use quote::quote; use syn::punctuated::Punctuated; use syn::token::Comma; @@ -161,16 +160,15 @@ fn expand_derive_decode_struct( } let (impl_generics, _, where_clause) = generics.split_for_impl(); - let mut reads: Vec = Vec::new(); - let mut names: Vec = Vec::new(); - for field in fields { + let reads = fields.iter().map(|field| -> Stmt { let id = &field.ident; - names.push(id.clone().unwrap()); let ty = &field.ty; - reads.push(parse_quote!( + parse_quote!( let #id = sqlx::postgres::decode_struct_field::<#ty>(&mut buf)?; - )); - } + ) + }); + + let names = fields.iter().map(|field| &field.ident); tts.extend(quote!( impl #impl_generics sqlx::decode::Decode for #ident#ty_generics #where_clause { diff --git a/sqlx-macros/src/derives/encode.rs b/sqlx-macros/src/derives/encode.rs index 84aa53f1..e5f4da45 100644 --- a/sqlx-macros/src/derives/encode.rs +++ b/sqlx-macros/src/derives/encode.rs @@ -160,22 +160,20 @@ fn expand_derive_encode_struct( } let (impl_generics, _, where_clause) = generics.split_for_impl(); - let mut writes: Vec = Vec::new(); - for field in fields { + let writes = fields.iter().map(|field| -> Stmt { let id = &field.ident; - writes.push(parse_quote!( + parse_quote!( sqlx::postgres::encode_struct_field(buf, &self. #id); - )); - } + ) + }); - let mut sizes: Vec = Vec::new(); - for field in fields { + let sizes = fields.iter().map(|field| -> Expr { let id = &field.ident; let ty = &field.ty; - sizes.push( - parse_quote!(<#ty as sqlx::encode::Encode>::size_hint(&self. #id)), - ); - } + parse_quote!( + <#ty as sqlx::encode::Encode>::size_hint(&self. #id) + ) + }); tts.extend(quote!( impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { From a2ba26dc7e17deb825dec6180d6d9518146812dd Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 10 Feb 2020 11:54:05 +0100 Subject: [PATCH 08/25] add explanation for size_hint --- sqlx-macros/src/derives/encode.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sqlx-macros/src/derives/encode.rs b/sqlx-macros/src/derives/encode.rs index e5f4da45..fc1d32c1 100644 --- a/sqlx-macros/src/derives/encode.rs +++ b/sqlx-macros/src/derives/encode.rs @@ -182,7 +182,9 @@ fn expand_derive_encode_struct( #(#writes)* } fn size_hint(&self) -> usize { - 4 + #column_count * (4 + 4) + #(#sizes)+* + 4 // oid + + #column_count * (4 + 4) // oid (int) and length (int) for each column + + #(#sizes)+* // sum of the size hints for each column } } )); From 6baddae9fdeaf012bc1064af4d438fbceb20fae1 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 10 Feb 2020 11:58:20 +0100 Subject: [PATCH 09/25] fix error messages --- sqlx-macros/src/derives/decode.rs | 15 +++++++++++++-- sqlx-macros/src/derives/encode.rs | 15 +++++++++++++-- sqlx-macros/src/derives/has_sql_type.rs | 15 +++++++++++++-- 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/sqlx-macros/src/derives/decode.rs b/sqlx-macros/src/derives/decode.rs index fa4604ef..23c6427d 100644 --- a/sqlx-macros/src/derives/decode.rs +++ b/sqlx-macros/src/derives/decode.rs @@ -27,9 +27,20 @@ pub fn expand_derive_decode(input: &DeriveInput) -> syn::Result expand_derive_decode_struct(input, named), - _ => Err(syn::Error::new_spanned( + Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")), + Data::Struct(DataStruct { + fields: Fields::Unnamed(..), + .. + }) => Err(syn::Error::new_spanned( input, - "expected a tuple struct with a single field", + "structs with zero or more than one unnamed field are not supported", + )), + Data::Struct(DataStruct { + fields: Fields::Unit, + .. + }) => Err(syn::Error::new_spanned( + input, + "unit structs are not supported", )), } } diff --git a/sqlx-macros/src/derives/encode.rs b/sqlx-macros/src/derives/encode.rs index fc1d32c1..1e0c2caa 100644 --- a/sqlx-macros/src/derives/encode.rs +++ b/sqlx-macros/src/derives/encode.rs @@ -28,9 +28,20 @@ pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result expand_derive_encode_struct(input, named), - _ => Err(syn::Error::new_spanned( + Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")), + Data::Struct(DataStruct { + fields: Fields::Unnamed(..), + .. + }) => Err(syn::Error::new_spanned( input, - "expected a tuple struct with a single field", + "structs with zero or more than one unnamed field are not supported", + )), + Data::Struct(DataStruct { + fields: Fields::Unit, + .. + }) => Err(syn::Error::new_spanned( + input, + "unit structs are not supported", )), } } diff --git a/sqlx-macros/src/derives/has_sql_type.rs b/sqlx-macros/src/derives/has_sql_type.rs index 64da9ae1..8251d500 100644 --- a/sqlx-macros/src/derives/has_sql_type.rs +++ b/sqlx-macros/src/derives/has_sql_type.rs @@ -27,9 +27,20 @@ pub fn expand_derive_has_sql_type(input: &DeriveInput) -> syn::Result expand_derive_has_sql_type_struct(input, named), - _ => Err(syn::Error::new_spanned( + Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")), + Data::Struct(DataStruct { + fields: Fields::Unnamed(..), + .. + }) => Err(syn::Error::new_spanned( input, - "expected a tuple struct with a single field", + "structs with zero or more than one unnamed field are not supported", + )), + Data::Struct(DataStruct { + fields: Fields::Unit, + .. + }) => Err(syn::Error::new_spanned( + input, + "unit structs are not supported", )), } } From a600b5b85628ae7eb4e0b43cec0a2f74f71c616d Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 10 Feb 2020 12:16:00 +0100 Subject: [PATCH 10/25] add tests for postgres struct field encoding --- Cargo.toml | 4 ++++ tests/postgres-struct.rs | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) create mode 100644 tests/postgres-struct.rs diff --git a/Cargo.toml b/Cargo.toml index 736f49cd..be794991 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -112,5 +112,9 @@ required-features = [ "mysql" ] name = "derives" required-features = [ "macros" ] +[[test]] +name = "postgres-struct" +required-features = [ "postgres" ] + [profile.release] lto = true diff --git a/tests/postgres-struct.rs b/tests/postgres-struct.rs new file mode 100644 index 00000000..d1083431 --- /dev/null +++ b/tests/postgres-struct.rs @@ -0,0 +1,39 @@ +use sqlx::encode::Encode; +use sqlx::postgres::{decode_struct_field, encode_struct_field}; +use sqlx::types::HasSqlType; +use sqlx::Postgres; +use std::convert::TryInto; + +#[test] +fn test_encode_field() { + let value = "Foo Bar"; + let mut raw_encoded = Vec::new(); + <&str as Encode>::encode(&value, &mut raw_encoded); + let mut field_encoded = Vec::new(); + encode_struct_field(&mut field_encoded, &value); + + // check oid + let oid = >::type_info().oid(); + let field_encoded_oid = u32::from_be_bytes(field_encoded[0..4].try_into().unwrap()); + assert_eq!(oid, field_encoded_oid); + + // check length + let field_encoded_length = u32::from_be_bytes(field_encoded[4..8].try_into().unwrap()); + assert_eq!(raw_encoded.len(), field_encoded_length as usize); + + // check data + assert_eq!(raw_encoded, &field_encoded[8..]); +} + +#[test] +fn test_decode_field() { + let value = "Foo Bar".to_string(); + + let mut buf = Vec::new(); + encode_struct_field(&mut buf, &value); + + let mut buf = buf.as_slice(); + let value_decoded: String = decode_struct_field(&mut buf).unwrap(); + assert_eq!(value_decoded, value); + assert!(buf.is_empty()); +} From 8841f83e6815589ac19bb8b1544130911b73f2d5 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 10 Feb 2020 12:20:17 +0100 Subject: [PATCH 11/25] removed unused imports --- sqlx-core/src/postgres/types/struct.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/sqlx-core/src/postgres/types/struct.rs b/sqlx-core/src/postgres/types/struct.rs index ee9a080e..e07f5c38 100644 --- a/sqlx-core/src/postgres/types/struct.rs +++ b/sqlx-core/src/postgres/types/struct.rs @@ -1,7 +1,5 @@ use crate::decode::{Decode, DecodeError}; use crate::encode::Encode; -use crate::postgres::protocol::TypeId; -use crate::postgres::types::PgTypeInfo; use crate::types::HasSqlType; use crate::Postgres; use std::convert::TryInto; From 62b591e63a656233c6067a1042c19e1381a85ab6 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 10 Feb 2020 12:38:05 +0100 Subject: [PATCH 12/25] use iterator change in expand_derive_strong_enum --- sqlx-macros/src/derives/decode.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/sqlx-macros/src/derives/decode.rs b/sqlx-macros/src/derives/decode.rs index 23c6427d..93f1adae 100644 --- a/sqlx-macros/src/derives/decode.rs +++ b/sqlx-macros/src/derives/decode.rs @@ -118,17 +118,16 @@ fn expand_derive_decode_strong_enum( let ident = &input.ident; - let mut value_arms = Vec::new(); - for v in variants { + let value_arms = variants.iter().map(|v| -> Arm { let id = &v.ident; - let attributes = parse_attributes(&v.attrs)?; + let attributes = parse_attributes(&v.attrs).unwrap(); if let Some(rename) = attributes.rename { - value_arms.push(quote!(#rename => Ok(#ident :: #id),)); + parse_quote!(#rename => Ok(#ident :: #id),) } else { let name = id.to_string(); - value_arms.push(quote!(#name => Ok(#ident :: #id),)); + parse_quote!(#name => Ok(#ident :: #id),) } - } + }); // TODO: prevent heap allocation Ok(quote!( From c76b3147d5b74300c5367e842cda5ef88a5cf116 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Tue, 17 Mar 2020 01:11:00 -0700 Subject: [PATCH 13/25] remove profile config from Cargo.toml --- Cargo.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index be794991..6b5ab80d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -115,6 +115,3 @@ required-features = [ "macros" ] [[test]] name = "postgres-struct" required-features = [ "postgres" ] - -[profile.release] -lto = true From ced6713f5773a63d012e0b279447edd9bcfcdb50 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Tue, 17 Mar 2020 01:11:15 -0700 Subject: [PATCH 14/25] add trailing newline --- sqlx-core/src/mysql/mod.rs | 2 +- sqlx-core/src/postgres/mod.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlx-core/src/mysql/mod.rs b/sqlx-core/src/mysql/mod.rs index 1cd9cd2a..7efddc38 100644 --- a/sqlx-core/src/mysql/mod.rs +++ b/sqlx-core/src/mysql/mod.rs @@ -29,4 +29,4 @@ pub type MySqlPool = crate::pool::Pool; make_query_as!(MySqlQueryAs, MySql, MySqlRow); impl_map_row_for_row!(MySql, MySqlRow); impl_column_index_for_row!(MySql); -impl_from_row_for_tuples!(MySql, MySqlRow); \ No newline at end of file +impl_from_row_for_tuples!(MySql, MySqlRow); diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs index d5f80737..f1dd5997 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -29,4 +29,4 @@ pub type PgPool = crate::pool::Pool; make_query_as!(PgQueryAs, Postgres, PgRow); impl_map_row_for_row!(Postgres, PgRow); impl_column_index_for_row!(Postgres); -impl_from_row_for_tuples!(Postgres, PgRow); \ No newline at end of file +impl_from_row_for_tuples!(Postgres, PgRow); From a5d17eab00b3db60d3d8415032e7d5fd3ae1b60f Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Tue, 17 Mar 2020 01:13:15 -0700 Subject: [PATCH 15/25] add derive(Debug) for PgValue --- sqlx-core/src/postgres/row.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/sqlx-core/src/postgres/row.rs b/sqlx-core/src/postgres/row.rs index 9158b2a7..c9224795 100644 --- a/sqlx-core/src/postgres/row.rs +++ b/sqlx-core/src/postgres/row.rs @@ -11,6 +11,7 @@ use crate::row::{ColumnIndex, Row}; /// A value from Postgres. This may be in a BINARY or TEXT format depending /// on the data type and if the query was prepared or not. +#[derive(Debug)] pub enum PgValue<'c> { Binary(&'c [u8]), Text(&'c str), From 4e7b1b51e0ac871965acadd3ca7633c4670906f7 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Tue, 17 Mar 2020 01:13:32 -0700 Subject: [PATCH 16/25] sqlite: handle encoding nulls --- sqlx-core/src/sqlite/arguments.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sqlx-core/src/sqlite/arguments.rs b/sqlx-core/src/sqlite/arguments.rs index 28099fef..8d4c6618 100644 --- a/sqlx-core/src/sqlite/arguments.rs +++ b/sqlx-core/src/sqlite/arguments.rs @@ -9,7 +9,7 @@ use libsqlite3_sys::{ }; use crate::arguments::Arguments; -use crate::encode::Encode; +use crate::encode::{Encode, IsNull}; use crate::sqlite::statement::Statement; use crate::sqlite::Sqlite; use crate::sqlite::SqliteError; @@ -63,7 +63,9 @@ impl Arguments for SqliteArguments { where T: Encode + Type, { - value.encode(&mut self.values); + if let IsNull::Yes = value.encode_nullable(&mut self.values) { + self.values.push(SqliteArgumentValue::Null); + } } } From 602e61ab2725f70128c8fd3f39d30e1a0283cb8e Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Tue, 17 Mar 2020 02:44:30 -0700 Subject: [PATCH 17/25] postgres: add support for decoding anonymous tuples and more fully test encoding/decoding records --- Cargo.toml | 4 - sqlx-core/src/postgres/mod.rs | 2 +- sqlx-core/src/postgres/types/mod.rs | 4 +- sqlx-core/src/postgres/types/record.rs | 328 +++++++++++++++++++++++++ sqlx-core/src/postgres/types/struct.rs | 57 ----- tests/postgres-struct.rs | 39 --- tests/postgres-types.rs | 195 ++++++++++++++- 7 files changed, 525 insertions(+), 104 deletions(-) create mode 100644 sqlx-core/src/postgres/types/record.rs delete mode 100644 sqlx-core/src/postgres/types/struct.rs delete mode 100644 tests/postgres-struct.rs diff --git a/Cargo.toml b/Cargo.toml index 6b5ab80d..e3cd32fa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -111,7 +111,3 @@ required-features = [ "mysql" ] [[test]] name = "derives" required-features = [ "macros" ] - -[[test]] -name = "postgres-struct" -required-features = [ "postgres" ] diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs index f1dd5997..853514ac 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -21,7 +21,7 @@ mod row; mod sasl; mod stream; mod tls; -mod types; +pub mod types; /// An alias for [`Pool`][crate::Pool], specialized for **Postgres**. pub type PgPool = crate::pool::Pool; diff --git a/sqlx-core/src/postgres/types/mod.rs b/sqlx-core/src/postgres/types/mod.rs index 91d1acc3..ab202c40 100644 --- a/sqlx-core/src/postgres/types/mod.rs +++ b/sqlx-core/src/postgres/types/mod.rs @@ -11,8 +11,10 @@ mod bool; mod bytes; mod float; mod int; +mod record; mod str; -pub mod r#struct; + +pub use self::record::{PgRecordDecoder, PgRecordEncoder}; #[cfg(feature = "chrono")] mod chrono; diff --git a/sqlx-core/src/postgres/types/record.rs b/sqlx-core/src/postgres/types/record.rs new file mode 100644 index 00000000..0e40f1ef --- /dev/null +++ b/sqlx-core/src/postgres/types/record.rs @@ -0,0 +1,328 @@ +use crate::decode::Decode; +use crate::encode::Encode; +use crate::io::Buf; +use crate::postgres::protocol::TypeId; +use crate::postgres::{PgTypeInfo, PgValue, Postgres}; +use crate::types::Type; +use byteorder::BigEndian; +use std::convert::TryInto; + +pub struct PgRecordEncoder<'a> { + buf: &'a mut Vec, + beg: usize, + num: u32, +} + +impl<'a> PgRecordEncoder<'a> { + pub fn new(buf: &'a mut Vec) -> Self { + // reserve space for a field count + buf.extend_from_slice(&(0_u32).to_be_bytes()); + + Self { + beg: buf.len(), + buf, + num: 0, + } + } + + pub fn finish(&mut self) { + // replaces zeros with actual length + self.buf[self.beg - 4..self.beg].copy_from_slice(&self.num.to_be_bytes()); + } + + pub fn encode(&mut self, value: T) -> &mut Self + where + T: Type + Encode, + { + // write oid + let info = T::type_info(); + self.buf.extend(&info.oid().to_be_bytes()); + + // write zeros for length + self.buf.extend(&[0; 4]); + + let start = self.buf.len(); + value.encode(self.buf); + + let end = self.buf.len(); + let size = end - start; + + // replaces zeros with actual length + self.buf[start - 4..start].copy_from_slice(&(size as u32).to_be_bytes()); + + // keep track of count + self.num += 1; + + self + } +} + +// impl Encode for (bool, i32, i64, f64, String) { +// fn encode(&self, buf: &mut Vec) { +// PgRecordEncoder::new(buf) +// .encode(self.0) +// .encode(self.1) +// .encode(self.2) +// .encode(self.3) +// .encode(&self.4) +// .finish() +// } +// +// fn size_hint(&self) -> usize { +// // for each field; oid, length, value +// 5 * (4 + 4) +// + (>::size_hint(&self.0) +// + >::size_hint(&self.1) +// + >::size_hint(&self.2) +// + >::size_hint(&self.3) +// + >::size_hint(&self.4)) +// } +// } + +pub struct PgRecordDecoder<'de> { + value: PgValue<'de>, +} + +impl<'de> PgRecordDecoder<'de> { + pub fn new(value: Option>) -> crate::Result { + let mut value: PgValue = value.try_into()?; + + match value { + PgValue::Binary(ref mut buf) => { + let _expected_len = buf.get_u32::()?; + } + + PgValue::Text(ref mut s) => { + // remove outer ( ... ) + *s = &s[1..(s.len() - 1)]; + } + } + + Ok(Self { value }) + } + + pub fn decode(&mut self) -> crate::Result + where + T: Decode<'de, Postgres>, + { + match self.value { + PgValue::Binary(ref mut buf) => { + // TODO: We should fail if this type is not _compatible_; but + // I want to make sure we handle this _and_ the outer level + // type mismatch errors at the same time + let _oid = buf.get_u32::()?; + let len = buf.get_i32::()? as isize; + + let value = if len < 0 { + T::decode(None)? + } else { + let value_buf = &buf[..(len as usize)]; + *buf = &buf[(len as usize)..]; + + T::decode(Some(PgValue::Binary(value_buf)))? + }; + + Ok(value) + } + + PgValue::Text(ref mut s) => { + let mut in_quotes = false; + let mut in_escape = false; + let mut is_quoted = false; + let mut prev_ch = '\0'; + let mut eos = false; + let mut prev_index = 0; + let mut value = String::new(); + + let index = 'outer: loop { + let mut iter = s.char_indices(); + while let Some((index, ch)) = iter.next() { + match ch { + ',' if !in_quotes => { + break 'outer Some(prev_index); + } + + ',' if prev_ch == '\0' => { + break 'outer None; + } + + '"' if prev_ch == '"' && index != 1 => { + // Quotes are escaped with another quote + in_quotes = false; + value.push('"'); + } + + '"' if in_quotes => { + in_quotes = false; + } + + '\'' if in_escape => { + in_escape = false; + value.push('\''); + } + + '"' if in_escape => { + in_escape = false; + value.push('"'); + } + + '\\' if in_escape => { + in_escape = false; + value.push('\\'); + } + + '\\' => { + in_escape = true; + } + + '"' => { + is_quoted = true; + in_quotes = true; + } + + ch => { + value.push(ch); + } + } + + prev_index = index; + prev_ch = ch; + } + + eos = true; + + break 'outer if prev_ch == '\0' { + // NULL values have zero characters + // Empty strings are "" + None + } else { + Some(prev_index) + }; + }; + + let value = index.map(|index| { + let mut s = &s[..=index]; + + if is_quoted { + s = &s[1..s.len() - 1]; + } + + PgValue::Text(s) + }); + + let value = T::decode(value)?; + + if !eos { + *s = &s[index.unwrap_or(0) + 2..]; + } else { + *s = ""; + } + + Ok(value) + } + } + } +} + +macro_rules! impl_pg_record_for_tuple { + ($( $idx:ident : $T:ident ),+) => { + impl<$($T,)+> Type for ($($T,)+) { + #[inline] + fn type_info() -> PgTypeInfo { + PgTypeInfo { + id: TypeId(2249), + name: Some("RECORD".into()), + } + } + } + + impl<'de, $($T,)+> Decode<'de, Postgres> for ($($T,)+) + where + $($T: crate::types::Type,)+ + $($T: crate::decode::Decode<'de, Postgres>,)+ + { + fn decode(value: Option>) -> crate::Result { + let mut decoder = PgRecordDecoder::new(value)?; + + $(let $idx: $T = decoder.decode()?;)+ + + Ok(($($idx,)+)) + } + } + }; +} + +impl_pg_record_for_tuple!(_1: T1); + +impl_pg_record_for_tuple!(_1: T1, _2: T2); + +impl_pg_record_for_tuple!(_1: T1, _2: T2, _3: T3); + +impl_pg_record_for_tuple!(_1: T1, _2: T2, _3: T3, _4: T4); + +impl_pg_record_for_tuple!(_1: T1, _2: T2, _3: T3, _4: T4, _5: T5); + +impl_pg_record_for_tuple!(_1: T1, _2: T2, _3: T3, _4: T4, _5: T5, _6: T6); + +impl_pg_record_for_tuple!(_1: T1, _2: T2, _3: T3, _4: T4, _5: T5, _6: T6, _7: T7); + +impl_pg_record_for_tuple!( + _1: T1, + _2: T2, + _3: T3, + _4: T4, + _5: T5, + _6: T6, + _7: T7, + _8: T8 +); + +impl_pg_record_for_tuple!( + _1: T1, + _2: T2, + _3: T3, + _4: T4, + _5: T5, + _6: T6, + _7: T7, + _8: T8, + _9: T9 +); + +#[test] +fn test_encode_field() { + let value = "Foo Bar"; + let mut raw_encoded = Vec::new(); + <&str as Encode>::encode(&value, &mut raw_encoded); + let mut field_encoded = Vec::new(); + + let mut encoder = PgRecordEncoder::new(&mut field_encoded); + encoder.encode(&value); + + // check oid + let oid = <&str as Type>::type_info().oid(); + let field_encoded_oid = u32::from_be_bytes(field_encoded[4..8].try_into().unwrap()); + assert_eq!(oid, field_encoded_oid); + + // check length + let field_encoded_length = u32::from_be_bytes(field_encoded[8..12].try_into().unwrap()); + assert_eq!(raw_encoded.len(), field_encoded_length as usize); + + // check data + assert_eq!(raw_encoded, &field_encoded[12..]); +} + +#[test] +fn test_decode_field() { + let value = "Foo Bar".to_string(); + + let mut buf = Vec::new(); + let mut encoder = PgRecordEncoder::new(&mut buf); + encoder.encode(&value); + + let mut buf = buf.as_slice(); + let mut decoder = PgRecordDecoder::new(Some(PgValue::Binary(buf))).unwrap(); + + let value_decoded: String = decoder.decode().unwrap(); + assert_eq!(value_decoded, value); +} diff --git a/sqlx-core/src/postgres/types/struct.rs b/sqlx-core/src/postgres/types/struct.rs deleted file mode 100644 index e07f5c38..00000000 --- a/sqlx-core/src/postgres/types/struct.rs +++ /dev/null @@ -1,57 +0,0 @@ -use crate::decode::{Decode, DecodeError}; -use crate::encode::Encode; -use crate::types::HasSqlType; -use crate::Postgres; -use std::convert::TryInto; - -/// read a struct field and advance the buffer -pub fn decode_struct_field>(buf: &mut &[u8]) -> Result -where - Postgres: HasSqlType, -{ - if buf.len() < 8 { - return Err(DecodeError::Message(std::boxed::Box::new( - "Not enough data sent", - ))); - } - - let oid = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[0..4]).unwrap()); - if oid != >::type_info().oid() { - return Err(DecodeError::Message(std::boxed::Box::new("Invalid oid"))); - } - - let len = u32::from_be_bytes(buf[4..8].try_into().unwrap()) as usize; - - if buf.len() < 8 + len { - return Err(DecodeError::Message(std::boxed::Box::new( - "Not enough data sent", - ))); - } - - let raw = &buf[8..8 + len]; - let value = T::decode(raw)?; - - *buf = &buf[8 + len..]; - - Ok(value) -} - -pub fn encode_struct_field>(buf: &mut Vec, value: &T) -where - Postgres: HasSqlType, -{ - // write oid - let info = >::type_info(); - buf.extend(&info.oid().to_be_bytes()); - - // write zeros for length - buf.extend(&[0; 4]); - - let start = buf.len(); - value.encode(buf); - let end = buf.len(); - let size = end - start; - - // replaces zeros with actual length - buf[start - 4..start].copy_from_slice(&(size as u32).to_be_bytes()); -} diff --git a/tests/postgres-struct.rs b/tests/postgres-struct.rs deleted file mode 100644 index d1083431..00000000 --- a/tests/postgres-struct.rs +++ /dev/null @@ -1,39 +0,0 @@ -use sqlx::encode::Encode; -use sqlx::postgres::{decode_struct_field, encode_struct_field}; -use sqlx::types::HasSqlType; -use sqlx::Postgres; -use std::convert::TryInto; - -#[test] -fn test_encode_field() { - let value = "Foo Bar"; - let mut raw_encoded = Vec::new(); - <&str as Encode>::encode(&value, &mut raw_encoded); - let mut field_encoded = Vec::new(); - encode_struct_field(&mut field_encoded, &value); - - // check oid - let oid = >::type_info().oid(); - let field_encoded_oid = u32::from_be_bytes(field_encoded[0..4].try_into().unwrap()); - assert_eq!(oid, field_encoded_oid); - - // check length - let field_encoded_length = u32::from_be_bytes(field_encoded[4..8].try_into().unwrap()); - assert_eq!(raw_encoded.len(), field_encoded_length as usize); - - // check data - assert_eq!(raw_encoded, &field_encoded[8..]); -} - -#[test] -fn test_decode_field() { - let value = "Foo Bar".to_string(); - - let mut buf = Vec::new(); - encode_struct_field(&mut buf, &value); - - let mut buf = buf.as_slice(); - let value_decoded: String = decode_struct_field(&mut buf).unwrap(); - assert_eq!(value_decoded, value); - assert!(buf.is_empty()); -} diff --git a/tests/postgres-types.rs b/tests/postgres-types.rs index 991a5bf6..46b52427 100644 --- a/tests/postgres-types.rs +++ b/tests/postgres-types.rs @@ -1,5 +1,11 @@ -use sqlx::Postgres; -use sqlx_test::test_type; +use sqlx::decode::Decode; +use sqlx::encode::Encode; +use sqlx::postgres::types::PgRecordEncoder; +use sqlx::postgres::{PgQueryAs, PgTypeInfo, PgValue}; +use sqlx::{Cursor, Executor, Postgres, Row, Type}; +use sqlx_core::postgres::types::PgRecordDecoder; +use sqlx_test::{new, test_type}; +use std::sync::atomic::{AtomicU32, Ordering}; test_type!(null( Postgres, @@ -87,3 +93,188 @@ mod chrono { ) )); } + +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn test_prepared_anonymous_record() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // Tuple of no elements is not possible + // Tuple of 1 element requires a concrete type + // Tuple with a NULL requires a concrete type + + // Tuple of 2 elements + let rec: ((bool, i32),) = sqlx::query_as("SELECT (true, 23512)") + .fetch_one(&mut conn) + .await?; + + assert_eq!((rec.0).0, true); + assert_eq!((rec.0).1, 23512); + + // Tuple with an empty string + let rec: ((bool, String),) = sqlx::query_as("SELECT (true,'')") + .fetch_one(&mut conn) + .await?; + + assert_eq!((rec.0).1, ""); + + // Tuple with a string with an interior comma + let rec: ((bool, String),) = sqlx::query_as("SELECT (true,'Hello, World!')") + .fetch_one(&mut conn) + .await?; + + assert_eq!((rec.0).1, "Hello, World!"); + + // Tuple with a string with an emoji + let rec: ((bool, String),) = sqlx::query_as("SELECT (true,'Hello, 🐕!')") + .fetch_one(&mut conn) + .await?; + + assert_eq!((rec.0).1, "Hello, 🐕!"); + + Ok(()) +} + +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn test_unprepared_anonymous_record() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // Tuple of no elements is not possible + // Tuple of 1 element requires a concrete type + // Tuple with a NULL requires a concrete type + + // Tuple of 2 elements + let mut cursor = conn.fetch("SELECT (true, 23512)"); + let row = cursor.next().await?.unwrap(); + let rec: (bool, i32) = row.get(0); + + assert_eq!(rec.0, true); + assert_eq!(rec.1, 23512); + + // Tuple with an empty string + let mut cursor = conn.fetch("SELECT (true, '')"); + let row = cursor.next().await?.unwrap(); + let rec: (bool, String) = row.get(0); + + assert_eq!(rec.1, ""); + + // Tuple with a string with an interior comma + let mut cursor = conn.fetch("SELECT (true, 'Hello, World!')"); + let row = cursor.next().await?.unwrap(); + let rec: (bool, String) = row.get(0); + + assert_eq!(rec.1, "Hello, World!"); + + // Tuple with a string with an emoji + let mut cursor = conn.fetch("SELECT (true, 'Hello, 🐕!')"); + let row = cursor.next().await?.unwrap(); + let rec: (bool, String) = row.get(0); + + assert_eq!(rec.1, "Hello, 🐕!"); + + Ok(()) +} + +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn test_prepared_structs() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // + // Setup custom types if needed + // + + static OID_RECORD_EMPTY: AtomicU32 = AtomicU32::new(0); + static OID_RECORD_1: AtomicU32 = AtomicU32::new(0); + + conn.execute( + r#" +DO $$ BEGIN + CREATE TYPE _sqlx_record_empty AS (); + CREATE TYPE _sqlx_record_1 AS (_1 int8); +EXCEPTION + WHEN duplicate_object THEN null; +END $$; + "#, + ) + .await?; + + let type_ids: Vec<(i32,)> = sqlx::query_as( + "SELECT oid::int4 FROM pg_type WHERE typname IN ('_sqlx_record_empty', '_sqlx_record_1')", + ) + .fetch_all(&mut conn) + .await?; + + OID_RECORD_EMPTY.store(type_ids[0].0 as u32, Ordering::SeqCst); + OID_RECORD_1.store(type_ids[1].0 as u32, Ordering::SeqCst); + + // + // Record of no elements + // + + struct RecordEmpty {} + + impl Type for RecordEmpty { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_oid(OID_RECORD_EMPTY.load(Ordering::SeqCst)) + } + } + + impl Encode for RecordEmpty { + fn encode(&self, buf: &mut Vec) { + PgRecordEncoder::new(buf).finish(); + } + } + + impl<'de> Decode<'de, Postgres> for RecordEmpty { + fn decode(_value: Option>) -> sqlx::Result { + Ok(RecordEmpty {}) + } + } + + let _: (RecordEmpty, RecordEmpty) = sqlx::query_as("SELECT '()'::_sqlx_record_empty, $1") + .bind(RecordEmpty {}) + .fetch_one(&mut conn) + .await?; + + // + // Record of one element + // + + #[derive(Debug, PartialEq)] + struct Record1 { + _1: i64, + } + + impl Type for Record1 { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_oid(OID_RECORD_1.load(Ordering::SeqCst)) + } + } + + impl Encode for Record1 { + fn encode(&self, buf: &mut Vec) { + PgRecordEncoder::new(buf).encode(self._1).finish(); + } + } + + impl<'de> Decode<'de, Postgres> for Record1 { + fn decode(value: Option>) -> sqlx::Result { + let mut decoder = PgRecordDecoder::new(value)?; + + let _1 = decoder.decode()?; + + Ok(Record1 { _1 }) + } + } + + let rec: (Record1, Record1) = sqlx::query_as("SELECT '(324235)'::_sqlx_record_1, $1") + .bind(Record1 { _1: 324235 }) + .fetch_one(&mut conn) + .await?; + + assert_eq!(rec.0, rec.1); + + Ok(()) +} From 4fc5e65f5d0207d4be7e36c8348d15e0f73c7e94 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Tue, 17 Mar 2020 03:24:08 -0700 Subject: [PATCH 18/25] derives: update transparent --- sqlx-macros/src/derives/decode.rs | 23 +- sqlx-macros/src/derives/encode.rs | 4 +- sqlx-macros/src/derives/mod.rs | 16 +- .../src/derives/{has_sql_type.rs => type.rs} | 44 +- sqlx-macros/src/lib.rs | 11 +- tests/derives.rs | 497 +++++++++--------- 6 files changed, 296 insertions(+), 299 deletions(-) rename sqlx-macros/src/derives/{has_sql_type.rs => type.rs} (83%) diff --git a/sqlx-macros/src/derives/decode.rs b/sqlx-macros/src/derives/decode.rs index 93f1adae..b6b818ee 100644 --- a/sqlx-macros/src/derives/decode.rs +++ b/sqlx-macros/src/derives/decode.rs @@ -61,25 +61,22 @@ fn expand_derive_decode_transparent( // add db type for impl generics & where clause let mut generics = generics.clone(); generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics.params.insert(0, parse_quote!('de)); generics .make_where_clause() .predicates - .push(parse_quote!(#ty: sqlx::decode::Decode)); + .push(parse_quote!(#ty: sqlx::decode::Decode<'de, DB>)); let (impl_generics, _, where_clause) = generics.split_for_impl(); - Ok(quote!( - impl #impl_generics sqlx::decode::Decode for #ident #ty_generics #where_clause { - fn decode(raw: &[u8]) -> std::result::Result { - <#ty as sqlx::decode::Decode>::decode(raw).map(Self) - } - fn decode_null() -> std::result::Result { - <#ty as sqlx::decode::Decode>::decode_null().map(Self) - } - fn decode_nullable(raw: std::option::Option<&[u8]>) -> std::result::Result { - <#ty as sqlx::decode::Decode>::decode_nullable(raw).map(Self) + let tts = quote!( + impl #impl_generics sqlx::decode::Decode<'de, DB> for #ident #ty_generics #where_clause { + fn decode(value: >::RawValue) -> sqlx::Result { + <#ty as sqlx::decode::Decode<'de, DB>>::decode(value).map(Self) } } - )) + ); + + Ok(tts) } fn expand_derive_decode_weak_enum( @@ -166,7 +163,7 @@ fn expand_derive_decode_struct( for field in fields { let ty = &field.ty; predicates.push(parse_quote!(#ty: sqlx::decode::Decode)); - predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); + predicates.push(parse_quote!(#ty: sqlx::types::Type)); } let (impl_generics, _, where_clause) = generics.split_for_impl(); diff --git a/sqlx-macros/src/derives/encode.rs b/sqlx-macros/src/derives/encode.rs index 1e0c2caa..56ad9fb7 100644 --- a/sqlx-macros/src/derives/encode.rs +++ b/sqlx-macros/src/derives/encode.rs @@ -70,10 +70,10 @@ fn expand_derive_encode_transparent( Ok(quote!( impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { - fn encode(&self, buf: &mut std::vec::Vec) { + fn encode(&self, buf: &mut DB::RawBuffer) { sqlx::encode::Encode::encode(&self.0, buf) } - fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { + fn encode_nullable(&self, buf: &mut DB::RawBuffer) -> sqlx::encode::IsNull { sqlx::encode::Encode::encode_nullable(&self.0, buf) } fn size_hint(&self) -> usize { diff --git a/sqlx-macros/src/derives/mod.rs b/sqlx-macros/src/derives/mod.rs index 28d9eee7..4e36533d 100644 --- a/sqlx-macros/src/derives/mod.rs +++ b/sqlx-macros/src/derives/mod.rs @@ -1,25 +1,25 @@ mod attributes; mod decode; mod encode; -mod has_sql_type; +mod r#type; pub(crate) use decode::expand_derive_decode; pub(crate) use encode::expand_derive_encode; -pub(crate) use has_sql_type::expand_derive_has_sql_type; +pub(crate) use r#type::expand_derive_type; use std::iter::FromIterator; use syn::DeriveInput; -pub(crate) fn expand_derive_type(input: &DeriveInput) -> syn::Result { +pub(crate) fn expand_derive_type_encode_decode( + input: &DeriveInput, +) -> syn::Result { let encode_tts = expand_derive_encode(input)?; let decode_tts = expand_derive_decode(input)?; - let has_sql_type_tts = expand_derive_has_sql_type(input)?; + let type_tts = expand_derive_type(input)?; let combined = proc_macro2::TokenStream::from_iter( - encode_tts - .into_iter() - .chain(decode_tts) - .chain(has_sql_type_tts), + encode_tts.into_iter().chain(decode_tts).chain(type_tts), ); + Ok(combined) } diff --git a/sqlx-macros/src/derives/has_sql_type.rs b/sqlx-macros/src/derives/type.rs similarity index 83% rename from sqlx-macros/src/derives/has_sql_type.rs rename to sqlx-macros/src/derives/type.rs index 8251d500..0c86cb8a 100644 --- a/sqlx-macros/src/derives/has_sql_type.rs +++ b/sqlx-macros/src/derives/type.rs @@ -10,7 +10,7 @@ use syn::{ FieldsUnnamed, Variant, }; -pub fn expand_derive_has_sql_type(input: &DeriveInput) -> syn::Result { +pub fn expand_derive_type(input: &DeriveInput) -> syn::Result { let attrs = parse_attributes(&input.attrs)?; match &input.data { Data::Struct(DataStruct { @@ -56,37 +56,39 @@ fn expand_derive_has_sql_type_transparent( // extract type generics let generics = &input.generics; - let (impl_generics, ty_generics, _) = generics.split_for_impl(); + let (_, ty_generics, _) = generics.split_for_impl(); // add db type for clause let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); generics .make_where_clause() .predicates - .push(parse_quote!(Self: sqlx::types::HasSqlType<#ty>)); - let (_, _, where_clause) = generics.split_for_impl(); + .push(parse_quote!(#ty: sqlx::types::Type)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); let mut tts = proc_macro2::TokenStream::new(); - if cfg!(feature = "mysql") { - tts.extend(quote!( - impl #impl_generics sqlx::types::HasSqlType< #ident #ty_generics > for sqlx::MySql #where_clause { - fn type_info() -> Self::TypeInfo { - >::type_info() - } + // if cfg!(feature = "mysql") { + tts.extend(quote!( + impl #impl_generics sqlx::types::Type< DB > for #ident #ty_generics #where_clause { + fn type_info() -> DB::TypeInfo { + <#ty as sqlx::Type>::type_info() } - )); - } + } + )); - if cfg!(feature = "postgres") { - tts.extend(quote!( - impl #impl_generics sqlx::types::HasSqlType< #ident #ty_generics > for sqlx::Postgres #where_clause { - fn type_info() -> Self::TypeInfo { - >::type_info() - } - } - )); - } + // } + + // if cfg!(feature = "postgres") { + // tts.extend(quote!( + // impl #impl_generics sqlx::types::HasSqlType< sqlx::Postgres > #ident #ty_generics #where_clause { + // fn type_info() -> Self::TypeInfo { + // >::type_info() + // } + // } + // )); + // } Ok(tts) } diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs index f74c775e..d0a19995 100644 --- a/sqlx-macros/src/lib.rs +++ b/sqlx-macros/src/lib.rs @@ -169,19 +169,10 @@ pub fn derive_decode(tokenstream: TokenStream) -> TokenStream { } } -#[proc_macro_derive(HasSqlType, attributes(sqlx))] -pub fn derive_has_sql_type(tokenstream: TokenStream) -> TokenStream { - let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); - match derives::expand_derive_has_sql_type(&input) { - Ok(ts) => ts.into(), - Err(e) => e.to_compile_error().into(), - } -} - #[proc_macro_derive(Type, attributes(sqlx))] pub fn derive_type(tokenstream: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); - match derives::expand_derive_type(&input) { + match derives::expand_derive_type_encode_decode(&input) { Ok(ts) => ts.into(), Err(e) => e.to_compile_error().into(), } diff --git a/tests/derives.rs b/tests/derives.rs index 4c22324e..5b5bd289 100644 --- a/tests/derives.rs +++ b/tests/derives.rs @@ -1,39 +1,40 @@ use sqlx::decode::Decode; use sqlx::encode::Encode; -use sqlx::types::{HasSqlType, TypeInfo}; +use sqlx::types::TypeInfo; +use sqlx::Type; use std::fmt::Debug; -#[derive(PartialEq, Debug, Encode, Decode, HasSqlType)] +#[derive(PartialEq, Debug, Type)] #[sqlx(transparent)] struct Transparent(i32); -#[derive(PartialEq, Debug, Clone, Copy, Encode, Decode, HasSqlType)] -#[repr(i32)] -#[allow(dead_code)] -enum Weak { - One, - Two, - Three, -} - -#[derive(PartialEq, Debug, Encode, Decode, HasSqlType)] -#[sqlx(postgres(oid = 10101010))] -#[allow(dead_code)] -enum Strong { - One, - Two, - #[sqlx(rename = "four")] - Three, -} - -#[derive(PartialEq, Debug, Encode, Decode, HasSqlType)] -#[sqlx(postgres(oid = 20202020))] -#[allow(dead_code)] -struct Struct { - field1: String, - field2: i64, - field3: bool, -} +// #[derive(PartialEq, Debug, Clone, Copy, Encode, Decode, HasSqlType)] +// #[repr(i32)] +// #[allow(dead_code)] +// enum Weak { +// One, +// Two, +// Three, +// } +// +// #[derive(PartialEq, Debug, Encode, Decode, HasSqlType)] +// #[sqlx(postgres(oid = 10101010))] +// #[allow(dead_code)] +// enum Strong { +// One, +// Two, +// #[sqlx(rename = "four")] +// Three, +// } +// +// #[derive(PartialEq, Debug, Encode, Decode, HasSqlType)] +// #[sqlx(postgres(oid = 20202020))] +// #[allow(dead_code)] +// struct Struct { +// field1: String, +// field2: i64, +// field3: bool, +// } #[test] #[cfg(feature = "mysql")] @@ -48,7 +49,7 @@ fn encode_transparent_postgres() { } #[allow(dead_code)] -fn encode_transparent() +fn encode_transparent>>() where Transparent: Encode, i32: Encode, @@ -63,124 +64,124 @@ where assert_eq!(encoded, encoded_orig); } - -#[test] -#[cfg(feature = "mysql")] -fn encode_weak_enum_mysql() { - encode_weak_enum::(); -} - -#[test] -#[cfg(feature = "postgres")] -fn encode_weak_enum_postgres() { - encode_weak_enum::(); -} - -#[allow(dead_code)] -fn encode_weak_enum() -where - Weak: Encode, - i32: Encode, -{ - for example in [Weak::One, Weak::Two, Weak::Three].iter() { - let mut encoded = Vec::new(); - let mut encoded_orig = Vec::new(); - - Encode::::encode(example, &mut encoded); - Encode::::encode(&(*example as i32), &mut encoded_orig); - - assert_eq!(encoded, encoded_orig); - } -} - -#[test] -#[cfg(feature = "mysql")] -fn encode_strong_enum_mysql() { - encode_strong_enum::(); -} - -#[test] -#[cfg(feature = "postgres")] -fn encode_strong_enum_postgres() { - encode_strong_enum::(); -} - -#[allow(dead_code)] -fn encode_strong_enum() -where - Strong: Encode, - str: Encode, -{ - for (example, name) in [ - (Strong::One, "One"), - (Strong::Two, "Two"), - (Strong::Three, "four"), - ] - .iter() - { - let mut encoded = Vec::new(); - let mut encoded_orig = Vec::new(); - - Encode::::encode(example, &mut encoded); - Encode::::encode(*name, &mut encoded_orig); - - assert_eq!(encoded, encoded_orig); - } -} - -#[test] -#[cfg(feature = "postgres")] -fn encode_struct_postgres() { - let field1 = "Foo".to_string(); - let field2 = 3; - let field3 = false; - - let example = Struct { - field1: field1.clone(), - field2, - field3, - }; - - let mut encoded = Vec::new(); - Encode::::encode(&example, &mut encoded); - - let string_oid = >::type_info().oid(); - let i64_oid = >::type_info().oid(); - let bool_oid = >::type_info().oid(); - - // 3 columns - assert_eq!(&[0, 0, 0, 3], &encoded[..4]); - let encoded = &encoded[4..]; - - // check field1 (string) - assert_eq!(&string_oid.to_be_bytes(), &encoded[0..4]); - assert_eq!(&(field1.len() as u32).to_be_bytes(), &encoded[4..8]); - assert_eq!(field1.as_bytes(), &encoded[8..8 + field1.len()]); - let encoded = &encoded[8 + field1.len()..]; - - // check field2 (i64) - assert_eq!(&i64_oid.to_be_bytes(), &encoded[0..4]); - assert_eq!(&8u32.to_be_bytes(), &encoded[4..8]); - assert_eq!(field2.to_be_bytes(), &encoded[8..16]); - let encoded = &encoded[16..]; - - // check field3 (bool) - assert_eq!(&bool_oid.to_be_bytes(), &encoded[0..4]); - assert_eq!(&1u32.to_be_bytes(), &encoded[4..8]); - assert_eq!(field3, encoded[8] != 0); - let encoded = &encoded[9..]; - - assert!(encoded.is_empty()); - - let string_size = >::size_hint(&field1); - let i64_size = >::size_hint(&field2); - let bool_size = >::size_hint(&field3); - - assert_eq!( - 4 + 3 * (4 + 4) + string_size + i64_size + bool_size, - example.size_hint() - ); -} +// +// #[test] +// #[cfg(feature = "mysql")] +// fn encode_weak_enum_mysql() { +// encode_weak_enum::(); +// } +// +// #[test] +// #[cfg(feature = "postgres")] +// fn encode_weak_enum_postgres() { +// encode_weak_enum::(); +// } +// +// #[allow(dead_code)] +// fn encode_weak_enum>>() +// where +// Weak: Encode, +// i32: Encode, +// { +// for example in [Weak::One, Weak::Two, Weak::Three].iter() { +// let mut encoded = Vec::new(); +// let mut encoded_orig = Vec::new(); +// +// Encode::::encode(example, &mut encoded); +// Encode::::encode(&(*example as i32), &mut encoded_orig); +// +// assert_eq!(encoded, encoded_orig); +// } +// } +// +// #[test] +// #[cfg(feature = "mysql")] +// fn encode_strong_enum_mysql() { +// encode_strong_enum::(); +// } +// +// #[test] +// #[cfg(feature = "postgres")] +// fn encode_strong_enum_postgres() { +// encode_strong_enum::(); +// } +// +// #[allow(dead_code)] +// fn encode_strong_enum>>() +// where +// Strong: Encode, +// str: Encode, +// { +// for (example, name) in [ +// (Strong::One, "One"), +// (Strong::Two, "Two"), +// (Strong::Three, "four"), +// ] +// .iter() +// { +// let mut encoded = Vec::new(); +// let mut encoded_orig = Vec::new(); +// +// Encode::::encode(example, &mut encoded); +// Encode::::encode(*name, &mut encoded_orig); +// +// assert_eq!(encoded, encoded_orig); +// } +// } +// +// #[test] +// #[cfg(feature = "postgres")] +// fn encode_struct_postgres() { +// let field1 = "Foo".to_string(); +// let field2 = 3; +// let field3 = false; +// +// let example = Struct { +// field1: field1.clone(), +// field2, +// field3, +// }; +// +// let mut encoded = Vec::new(); +// Encode::::encode(&example, &mut encoded); +// +// let string_oid = >::type_info().oid(); +// let i64_oid = >::type_info().oid(); +// let bool_oid = >::type_info().oid(); +// +// // 3 columns +// assert_eq!(&[0, 0, 0, 3], &encoded[..4]); +// let encoded = &encoded[4..]; +// +// // check field1 (string) +// assert_eq!(&string_oid.to_be_bytes(), &encoded[0..4]); +// assert_eq!(&(field1.len() as u32).to_be_bytes(), &encoded[4..8]); +// assert_eq!(field1.as_bytes(), &encoded[8..8 + field1.len()]); +// let encoded = &encoded[8 + field1.len()..]; +// +// // check field2 (i64) +// assert_eq!(&i64_oid.to_be_bytes(), &encoded[0..4]); +// assert_eq!(&8u32.to_be_bytes(), &encoded[4..8]); +// assert_eq!(field2.to_be_bytes(), &encoded[8..16]); +// let encoded = &encoded[16..]; +// +// // check field3 (bool) +// assert_eq!(&bool_oid.to_be_bytes(), &encoded[0..4]); +// assert_eq!(&1u32.to_be_bytes(), &encoded[4..8]); +// assert_eq!(field3, encoded[8] != 0); +// let encoded = &encoded[9..]; +// +// assert!(encoded.is_empty()); +// +// let string_size = >::size_hint(&field1); +// let i64_size = >::size_hint(&field2); +// let bool_size = >::size_hint(&field3); +// +// assert_eq!( +// 4 + 3 * (4 + 4) + string_size + i64_size + bool_size, +// example.size_hint() +// ); +// } #[test] #[cfg(feature = "mysql")] @@ -193,119 +194,125 @@ fn decode_transparent_mysql() { fn decode_transparent_postgres() { decode_with_db::(Transparent(0x1122_3344)); } - -#[test] -#[cfg(feature = "mysql")] -fn decode_weak_enum_mysql() { - decode_with_db::(Weak::One); - decode_with_db::(Weak::Two); - decode_with_db::(Weak::Three); -} - -#[test] -#[cfg(feature = "postgres")] -fn decode_weak_enum_postgres() { - decode_with_db::(Weak::One); - decode_with_db::(Weak::Two); - decode_with_db::(Weak::Three); -} - -#[test] -#[cfg(feature = "mysql")] -fn decode_strong_enum_mysql() { - decode_with_db::(Strong::One); - decode_with_db::(Strong::Two); - decode_with_db::(Strong::Three); -} - -#[test] -#[cfg(feature = "postgres")] -fn decode_strong_enum_postgres() { - decode_with_db::(Strong::One); - decode_with_db::(Strong::Two); - decode_with_db::(Strong::Three); -} - -#[test] -#[cfg(feature = "postgres")] -fn decode_struct_postgres() { - decode_with_db::(Struct { - field1: "Foo".to_string(), - field2: 3, - field3: true, - }); -} - +// +// #[test] +// #[cfg(feature = "mysql")] +// fn decode_weak_enum_mysql() { +// decode_with_db::(Weak::One); +// decode_with_db::(Weak::Two); +// decode_with_db::(Weak::Three); +// } +// +// #[test] +// #[cfg(feature = "postgres")] +// fn decode_weak_enum_postgres() { +// decode_with_db::(Weak::One); +// decode_with_db::(Weak::Two); +// decode_with_db::(Weak::Three); +// } +// +// #[test] +// #[cfg(feature = "mysql")] +// fn decode_strong_enum_mysql() { +// decode_with_db::(Strong::One); +// decode_with_db::(Strong::Two); +// decode_with_db::(Strong::Three); +// } +// +// #[test] +// #[cfg(feature = "postgres")] +// fn decode_strong_enum_postgres() { +// decode_with_db::(Strong::One); +// decode_with_db::(Strong::Two); +// decode_with_db::(Strong::Three); +// } +// +// #[test] +// #[cfg(feature = "postgres")] +// fn decode_struct_postgres() { +// decode_with_db::(Struct { +// field1: "Foo".to_string(), +// field2: 3, +// field3: true, +// }); +// } +// #[allow(dead_code)] -fn decode_with_db + Encode + PartialEq + Debug>(example: V) { +fn decode_with_db< + DB: sqlx::Database>, + V: for<'de> Decode<'de, DB> + Encode + PartialEq + Debug, +>( + example: V, +) { let mut encoded = Vec::new(); Encode::::encode(&example, &mut encoded); - let decoded = V::decode(&encoded).unwrap(); - assert_eq!(example, decoded); + // let decoded = V::decode(&encoded).unwrap(); + // assert_eq!(example, decoded); } #[test] #[cfg(feature = "mysql")] -fn has_sql_type_transparent_mysql() { - has_sql_type_transparent::(); +fn type_transparent_mysql() { + type_transparent::(); } #[test] #[cfg(feature = "postgres")] -fn has_sql_type_transparent_postgres() { - has_sql_type_transparent::(); +fn type_transparent_postgres() { + type_transparent::(); } #[allow(dead_code)] -fn has_sql_type_transparent() +fn type_transparent>>() where - DB: HasSqlType + HasSqlType, + Transparent: Type, + i32: Type, { - let info: DB::TypeInfo = >::type_info(); - let info_orig: DB::TypeInfo = >::type_info(); + let info: DB::TypeInfo = >::type_info(); + let info_orig: DB::TypeInfo = >::type_info(); assert!(info.compatible(&info_orig)); } - -#[test] -#[cfg(feature = "mysql")] -fn has_sql_type_weak_enum_mysql() { - has_sql_type_weak_enum::(); -} - -#[test] -#[cfg(feature = "postgres")] -fn has_sql_type_weak_enum_postgres() { - has_sql_type_weak_enum::(); -} - -#[allow(dead_code)] -fn has_sql_type_weak_enum() -where - DB: HasSqlType + HasSqlType, -{ - let info: DB::TypeInfo = >::type_info(); - let info_orig: DB::TypeInfo = >::type_info(); - assert!(info.compatible(&info_orig)); -} - -#[test] -#[cfg(feature = "mysql")] -fn has_sql_type_strong_enum_mysql() { - let info: sqlx::mysql::MySqlTypeInfo = >::type_info(); - assert!(info.compatible(&sqlx::mysql::MySqlTypeInfo::r#enum())) -} - -#[test] -#[cfg(feature = "postgres")] -fn has_sql_type_strong_enum_postgres() { - let info: sqlx::postgres::PgTypeInfo = >::type_info(); - assert!(info.compatible(&sqlx::postgres::PgTypeInfo::with_oid(10101010))) -} - -#[test] -#[cfg(feature = "postgres")] -fn has_sql_type_struct_postgres() { - let info: sqlx::postgres::PgTypeInfo = >::type_info(); - assert!(info.compatible(&sqlx::postgres::PgTypeInfo::with_oid(20202020))) -} +// +// #[test] +// #[cfg(feature = "mysql")] +// fn type_weak_enum_mysql() { +// type_weak_enum::(); +// } +// +// #[test] +// #[cfg(feature = "postgres")] +// fn type_weak_enum_postgres() { +// type_weak_enum::(); +// } +// +// #[allow(dead_code)] +// fn type_weak_enum>>() +// where +// DB: HasSqlType + HasSqlType, +// { +// let info: DB::TypeInfo = >::type_info(); +// let info_orig: DB::TypeInfo = >::type_info(); +// assert!(info.compatible(&info_orig)); +// } +// +// #[test] +// #[cfg(feature = "mysql")] +// fn type_strong_enum_mysql() { +// let info: sqlx::mysql::MySqlTypeInfo = >::type_info(); +// assert!(info.compatible(&sqlx::mysql::MySqlTypeInfo::r#enum())) +// } +// +// #[test] +// #[cfg(feature = "postgres")] +// fn type_strong_enum_postgres() { +// let info: sqlx::postgres::PgTypeInfo = >::type_info(); +// assert!(info.compatible(&sqlx::postgres::PgTypeInfo::with_oid(10101010))) +// } +// +// #[test] +// #[cfg(feature = "postgres")] +// fn type_struct_postgres() { +// let info: sqlx::postgres::PgTypeInfo = >::type_info(); +// assert!(info.compatible(&sqlx::postgres::PgTypeInfo::with_oid(20202020))) +// } From 100602187f6aa0f84208675f44d529641f2ae5cb Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Tue, 17 Mar 2020 16:33:39 -0700 Subject: [PATCH 19/25] memo: add more documentation to the database module --- sqlx-core/src/database.rs | 37 +++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/sqlx-core/src/database.rs b/sqlx-core/src/database.rs index 70dea800..c9a11dcb 100644 --- a/sqlx-core/src/database.rs +++ b/sqlx-core/src/database.rs @@ -9,7 +9,7 @@ use crate::types::TypeInfo; /// A database driver. /// /// This trait encapsulates a complete driver implementation to a specific -/// database (e.g., MySQL, Postgres). +/// database (e.g., **MySQL**, **Postgres**). pub trait Database where Self: Sized + Send + 'static, @@ -29,21 +29,54 @@ where /// The Rust type of table identifiers for this database. type TableId: Display + Clone; - type RawBuffer; + /// The Rust type used as the buffer when encoding arguments. + /// + /// For example, **Postgres** and **MySQL** use `Vec`; however, **SQLite** uses `Vec`. + type RawBuffer: Default; } +/// Associate [`Database`] with a `RawValue` of a generic lifetime. +/// +/// --- +/// +/// The upcoming Rust feature, [Generic Associated Types], should obviate +/// the need for this trait. +/// +/// [Generic Associated Types]: https://www.google.com/search?q=generic+associated+types+rust&oq=generic+associated+types+rust&aqs=chrome..69i57j0l5.3327j0j7&sourceid=chrome&ie=UTF-8 pub trait HasRawValue<'c> { + /// The Rust type used to hold a not-yet-decoded value that has just been + /// received from the database. + /// + /// For example, **Postgres** and **MySQL** use `&'c [u8]`; however, **SQLite** uses `SqliteValue<'c>`. type RawValue; } +/// Associate [`Database`] with a [`Cursor`] of a generic lifetime. +/// +/// --- +/// +/// The upcoming Rust feature, [Generic Associated Types], should obviate +/// the need for this trait. +/// +/// [Generic Associated Types]: https://www.google.com/search?q=generic+associated+types+rust&oq=generic+associated+types+rust&aqs=chrome..69i57j0l5.3327j0j7&sourceid=chrome&ie=UTF-8 pub trait HasCursor<'c, 'q> { type Database: Database; + /// The concrete `Cursor` implementation for this database. type Cursor: Cursor<'c, 'q, Database = Self::Database>; } +/// Associate [`Database`] with a [`Row`] of a generic lifetime. +/// +/// --- +/// +/// The upcoming Rust feature, [Generic Associated Types], should obviate +/// the need for this trait. +/// +/// [Generic Associated Types]: https://www.google.com/search?q=generic+associated+types+rust&oq=generic+associated+types+rust&aqs=chrome..69i57j0l5.3327j0j7&sourceid=chrome&ie=UTF-8 pub trait HasRow<'c> { type Database: Database; + /// The concrete `Row` implementation for this database. type Row: Row<'c, Database = Self::Database>; } From c1e6b2045c38675abaa61f6a1291c19fd3efe4b4 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Tue, 17 Mar 2020 19:26:15 -0700 Subject: [PATCH 20/25] postgres: support null in a record --- sqlx-core/src/postgres/types/record.rs | 39 +++++++------------------- 1 file changed, 10 insertions(+), 29 deletions(-) diff --git a/sqlx-core/src/postgres/types/record.rs b/sqlx-core/src/postgres/types/record.rs index 0e40f1ef..c4955704 100644 --- a/sqlx-core/src/postgres/types/record.rs +++ b/sqlx-core/src/postgres/types/record.rs @@ -1,5 +1,5 @@ use crate::decode::Decode; -use crate::encode::Encode; +use crate::encode::{Encode, IsNull}; use crate::io::Buf; use crate::postgres::protocol::TypeId; use crate::postgres::{PgTypeInfo, PgValue, Postgres}; @@ -42,13 +42,16 @@ impl<'a> PgRecordEncoder<'a> { self.buf.extend(&[0; 4]); let start = self.buf.len(); - value.encode(self.buf); + if let IsNull::Yes = value.encode_nullable(self.buf) { + // replaces zeros with actual length + self.buf[start - 4..start].copy_from_slice(&(-1_i32).to_be_bytes()); + } else { + let end = self.buf.len(); + let size = end - start; - let end = self.buf.len(); - let size = end - start; - - // replaces zeros with actual length - self.buf[start - 4..start].copy_from_slice(&(size as u32).to_be_bytes()); + // replaces zeros with actual length + self.buf[start - 4..start].copy_from_slice(&(size as u32).to_be_bytes()); + } // keep track of count self.num += 1; @@ -57,28 +60,6 @@ impl<'a> PgRecordEncoder<'a> { } } -// impl Encode for (bool, i32, i64, f64, String) { -// fn encode(&self, buf: &mut Vec) { -// PgRecordEncoder::new(buf) -// .encode(self.0) -// .encode(self.1) -// .encode(self.2) -// .encode(self.3) -// .encode(&self.4) -// .finish() -// } -// -// fn size_hint(&self) -> usize { -// // for each field; oid, length, value -// 5 * (4 + 4) -// + (>::size_hint(&self.0) -// + >::size_hint(&self.1) -// + >::size_hint(&self.2) -// + >::size_hint(&self.3) -// + >::size_hint(&self.4)) -// } -// } - pub struct PgRecordDecoder<'de> { value: PgValue<'de>, } From 21059620dc6b4906a369d545887920eb8f207f09 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Tue, 17 Mar 2020 19:26:31 -0700 Subject: [PATCH 21/25] mysql: support understanding ENUM as TEXT --- sqlx-core/src/mysql/protocol/row.rs | 1 + sqlx-core/src/mysql/types/mod.rs | 5 ----- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/sqlx-core/src/mysql/protocol/row.rs b/sqlx-core/src/mysql/protocol/row.rs index 384db51c..15096376 100644 --- a/sqlx-core/src/mysql/protocol/row.rs +++ b/sqlx-core/src/mysql/protocol/row.rs @@ -126,6 +126,7 @@ impl<'c> Row<'c> { | TypeId::LONG_BLOB | TypeId::CHAR | TypeId::TEXT + | TypeId::ENUM | TypeId::VAR_CHAR => { let (len_size, len) = get_lenenc(&buffer[index..]); diff --git a/sqlx-core/src/mysql/types/mod.rs b/sqlx-core/src/mysql/types/mod.rs index 343d39aa..2c4421f1 100644 --- a/sqlx-core/src/mysql/types/mod.rs +++ b/sqlx-core/src/mysql/types/mod.rs @@ -64,11 +64,6 @@ impl MySqlTypeInfo { _ => None, } } - - #[doc(hidden)] - pub fn r#enum() -> Self { - Self::new(TypeId::ENUM) - } } impl Display for MySqlTypeInfo { From d77b2b1e97aa91278167768196debe81fa8630c5 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Tue, 17 Mar 2020 19:26:59 -0700 Subject: [PATCH 22/25] derives: update for new Decode/Encode traits and extensively test in usage --- Cargo.toml | 8 + sqlx-macros/src/derives/attributes.rs | 202 +++++++++------- sqlx-macros/src/derives/decode.rs | 92 ++++---- sqlx-macros/src/derives/encode.rs | 48 ++-- sqlx-macros/src/derives/mod.rs | 9 + sqlx-macros/src/derives/type.rs | 76 ++---- tests/derives.rs | 318 -------------------------- tests/mysql-derives.rs | 47 ++++ tests/postgres-derives.rs | 81 +++++++ tests/postgres.rs | 22 +- 10 files changed, 378 insertions(+), 525 deletions(-) delete mode 100644 tests/derives.rs create mode 100644 tests/mysql-derives.rs create mode 100644 tests/postgres-derives.rs diff --git a/Cargo.toml b/Cargo.toml index e3cd32fa..6e3443c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -92,6 +92,10 @@ required-features = [ "mysql" ] name = "mysql-raw" required-features = [ "mysql" ] +[[test]] +name = "mysql-derives" +required-features = [ "mysql", "macros" ] + [[test]] name = "postgres" required-features = [ "postgres" ] @@ -104,6 +108,10 @@ required-features = [ "postgres" ] name = "postgres-types" required-features = [ "postgres" ] +[[test]] +name = "postgres-derives" +required-features = [ "postgres", "macros" ] + [[test]] name = "mysql-types" required-features = [ "mysql" ] diff --git a/sqlx-macros/src/derives/attributes.rs b/sqlx-macros/src/derives/attributes.rs index 72df6991..a75c82a8 100644 --- a/sqlx-macros/src/derives/attributes.rs +++ b/sqlx-macros/src/derives/attributes.rs @@ -11,33 +11,42 @@ macro_rules! assert_attribute { }; } -pub struct SqlxAttributes { +macro_rules! fail { + ($t:expr, $m:expr) => { + return Err(syn::Error::new_spanned($t, $m)); + }; +} + +macro_rules! try_set { + ($i:ident, $v:expr, $t:expr) => { + match $i { + None => $i = Some($v), + Some(_) => fail!($t, "duplicate attribute"), + } + }; +} + +#[derive(Copy, Clone)] +pub enum RenameAll { + LowerCase, +} + +pub struct SqlxContainerAttributes { pub transparent: bool, pub postgres_oid: Option, + pub rename_all: Option, pub repr: Option, +} + +pub struct SqlxChildAttributes { pub rename: Option, } -pub fn parse_attributes(input: &[Attribute]) -> syn::Result { +pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result { let mut transparent = None; let mut postgres_oid = None; let mut repr = None; - let mut rename = None; - - macro_rules! fail { - ($t:expr, $m:expr) => { - return Err(syn::Error::new_spanned($t, $m)); - }; - } - - macro_rules! try_set { - ($i:ident, $v:expr, $t:expr) => { - match $i { - None => $i = Some($v), - Some(_) => fail!($t, "duplicate attribute"), - } - }; - } + let mut rename_all = None; for attr in input { let meta = attr @@ -51,11 +60,21 @@ pub fn parse_attributes(input: &[Attribute]) -> syn::Result { Meta::Path(p) if p.is_ident("transparent") => { try_set!(transparent, true, value) } + Meta::NameValue(MetaNameValue { path, lit: Lit::Str(val), .. - }) if path.is_ident("rename") => try_set!(rename, val.value(), value), + }) if path.is_ident("rename_all") => { + let val = match &*val.value() { + "lowercase" => RenameAll::LowerCase, + + _ => fail!(meta, "unexpected value for rename_all"), + }; + + try_set!(rename_all, val, value) + }, + Meta::List(list) if list.path.is_ident("postgres") => { for value in list.nested.iter() { match value { @@ -92,85 +111,93 @@ pub fn parse_attributes(input: &[Attribute]) -> syn::Result { } } - Ok(SqlxAttributes { + Ok(SqlxContainerAttributes { transparent: transparent.unwrap_or(false), postgres_oid, repr, + rename_all, + }) +} + +pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result { + let mut rename = None; + + for attr in input { + let meta = attr + .parse_meta() + .map_err(|e| syn::Error::new_spanned(attr, e))?; + + match meta { + Meta::List(list) if list.path.is_ident("sqlx") => { + for value in list.nested.iter() { + match value { + NestedMeta::Meta(meta) => match meta { + Meta::NameValue(MetaNameValue { + path, + lit: Lit::Str(val), + .. + }) if path.is_ident("rename") => try_set!(rename, val.value(), value), + + u => fail!(u, "unexpected attribute"), + }, + u => fail!(u, "unexpected attribute"), + } + } + } + _ => {} + } + } + + Ok(SqlxChildAttributes { rename, }) } pub fn check_transparent_attributes(input: &DeriveInput, field: &Field) -> syn::Result<()> { - let attributes = parse_attributes(&input.attrs)?; + let attributes = parse_container_attributes(&input.attrs)?; + assert_attribute!( attributes.transparent, "expected #[sqlx(transparent)]", input ); + #[cfg(feature = "postgres")] assert_attribute!( attributes.postgres_oid.is_none(), "unexpected #[sqlx(postgres(oid = ..))]", input ); + assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", + attributes.rename_all.is_none(), + "unexpected #[sqlx(rename_all = ..)]", field ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); - let attributes = parse_attributes(&field.attrs)?; - assert_attribute!( - !attributes.transparent, - "unexpected #[sqlx(transparent)]", - field - ); - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_none(), - "unexpected #[sqlx(postgres(oid = ..))]", - field - ); + + let attributes = parse_child_attributes(&field.attrs)?; + assert_attribute!( attributes.rename.is_none(), "unexpected #[sqlx(rename = ..)]", field ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field); + Ok(()) } pub fn check_enum_attributes<'a>( input: &'a DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let attributes = parse_attributes(&input.attrs)?; +) -> syn::Result { + let attributes = parse_container_attributes(&input.attrs)?; + assert_attribute!( !attributes.transparent, "unexpected #[sqlx(transparent)]", input ); - assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", - input - ); - - for variant in variants { - let attributes = parse_attributes(&variant.attrs)?; - assert_attribute!( - !attributes.transparent, - "unexpected #[sqlx(transparent)]", - variant - ); - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_none(), - "unexpected #[sqlx(postgres(oid = ..))]", - variant - ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", variant); - } Ok(attributes) } @@ -178,83 +205,90 @@ pub fn check_enum_attributes<'a>( pub fn check_weak_enum_attributes( input: &DeriveInput, variants: &Punctuated, -) -> syn::Result { - let attributes = check_enum_attributes(input, variants)?; +) -> syn::Result { + let attributes = check_enum_attributes(input)?; + #[cfg(feature = "postgres")] assert_attribute!( attributes.postgres_oid.is_none(), "unexpected #[sqlx(postgres(oid = ..))]", input ); + assert_attribute!(attributes.repr.is_some(), "expected #[repr(..)]", input); + + assert_attribute!( + attributes.rename_all.is_none(), + "unexpected #[sqlx(c = ..)]", + input + ); + for variant in variants { - let attributes = parse_attributes(&variant.attrs)?; + let attributes = parse_child_attributes(&variant.attrs)?; + assert_attribute!( attributes.rename.is_none(), "unexpected #[sqlx(rename = ..)]", variant ); } - Ok(attributes.repr.unwrap()) + + Ok(attributes) } pub fn check_strong_enum_attributes( input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let attributes = check_enum_attributes(input, variants)?; + _variants: &Punctuated, +) -> syn::Result { + let attributes = check_enum_attributes(input)?; + #[cfg(feature = "postgres")] assert_attribute!( attributes.postgres_oid.is_some(), "expected #[sqlx(postgres(oid = ..))]", input ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); + Ok(attributes) } pub fn check_struct_attributes<'a>( input: &'a DeriveInput, fields: &Punctuated, -) -> syn::Result { - let attributes = parse_attributes(&input.attrs)?; +) -> syn::Result { + let attributes = parse_container_attributes(&input.attrs)?; + assert_attribute!( !attributes.transparent, "unexpected #[sqlx(transparent)]", input ); + #[cfg(feature = "postgres")] assert_attribute!( attributes.postgres_oid.is_some(), "expected #[sqlx(postgres(oid = ..))]", input ); + assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", + attributes.rename_all.is_none(), + "unexpected #[sqlx(rename_all = ..)]", input ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); for field in fields { - let attributes = parse_attributes(&field.attrs)?; - assert_attribute!( - !attributes.transparent, - "unexpected #[sqlx(transparent)]", - field - ); - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_none(), - "unexpected #[sqlx(postgres(oid = ..))]", - field - ); + let attributes = parse_child_attributes(&field.attrs)?; + assert_attribute!( attributes.rename.is_none(), "unexpected #[sqlx(rename = ..)]", field ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field); } Ok(attributes) diff --git a/sqlx-macros/src/derives/decode.rs b/sqlx-macros/src/derives/decode.rs index b6b818ee..9dae6fa3 100644 --- a/sqlx-macros/src/derives/decode.rs +++ b/sqlx-macros/src/derives/decode.rs @@ -1,7 +1,9 @@ use super::attributes::{ check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, - check_weak_enum_attributes, parse_attributes, + check_weak_enum_attributes, parse_container_attributes, + parse_child_attributes, }; +use super::rename_all; use quote::quote; use syn::punctuated::Punctuated; use syn::token::Comma; @@ -11,7 +13,7 @@ use syn::{ }; pub fn expand_derive_decode(input: &DeriveInput) -> syn::Result { - let attrs = parse_attributes(&input.attrs)?; + let attrs = parse_container_attributes(&input.attrs)?; match &input.data { Data::Struct(DataStruct { fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), @@ -83,24 +85,29 @@ fn expand_derive_decode_weak_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { - let repr = check_weak_enum_attributes(input, &variants)?; + let attr = check_weak_enum_attributes(input, &variants)?; + let repr = attr.repr.unwrap(); let ident = &input.ident; + let ident_s = ident.to_string(); + let arms = variants .iter() .map(|v| { let id = &v.ident; - parse_quote!(_ if (#ident :: #id as #repr) == val => Ok(#ident :: #id),) + parse_quote!(_ if (#ident :: #id as #repr) == value => Ok(#ident :: #id),) }) .collect::>(); Ok(quote!( - impl sqlx::decode::Decode for #ident where #repr: sqlx::decode::Decode { - fn decode(raw: &[u8]) -> std::result::Result { - let val = <#repr as sqlx::decode::Decode>::decode(raw)?; - match val { + impl<'de, DB: sqlx::Database> sqlx::decode::Decode<'de, DB> for #ident where #repr: sqlx::decode::Decode<'de, DB> { + fn decode(value: >::RawValue) -> sqlx::Result { + let value = <#repr as sqlx::decode::Decode<'de, DB>>::decode(value)?; + + match value { #(#arms)* - _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) + + _ => Err(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into())) } } } @@ -111,29 +118,35 @@ fn expand_derive_decode_strong_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { - check_strong_enum_attributes(input, &variants)?; + let cattr = check_strong_enum_attributes(input, &variants)?; let ident = &input.ident; + let ident_s = ident.to_string(); let value_arms = variants.iter().map(|v| -> Arm { let id = &v.ident; - let attributes = parse_attributes(&v.attrs).unwrap(); + let attributes = parse_child_attributes(&v.attrs).unwrap(); + if let Some(rename) = attributes.rename { parse_quote!(#rename => Ok(#ident :: #id),) + } else if let Some(pattern) = cattr.rename_all { + let name = rename_all(&*id.to_string(), pattern); + + parse_quote!(#name => Ok(#ident :: #id),) } else { let name = id.to_string(); parse_quote!(#name => Ok(#ident :: #id),) } }); - // TODO: prevent heap allocation Ok(quote!( - impl sqlx::decode::Decode for #ident where std::string::String: sqlx::decode::Decode { - fn decode(buf: &[u8]) -> std::result::Result { - let val = >::decode(buf)?; - match val.as_str() { + impl<'de, DB: sqlx::Database> sqlx::decode::Decode<'de, DB> for #ident where &'de str: sqlx::decode::Decode<'de, DB> { + fn decode(value: >::RawValue) -> sqlx::Result { + let value = <&'de str as sqlx::decode::Decode<'de, DB>>::decode(value)?; + match value { #(#value_arms)* - _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) + + _ => Err(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into())) } } } @@ -151,57 +164,50 @@ fn expand_derive_decode_struct( if cfg!(feature = "postgres") { let ident = &input.ident; - let column_count = fields.len(); - // extract type generics let generics = &input.generics; let (_, ty_generics, _) = generics.split_for_impl(); // add db type for impl generics & where clause let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!('de)); + let predicates = &mut generics.make_where_clause().predicates; + for field in fields { let ty = &field.ty; - predicates.push(parse_quote!(#ty: sqlx::decode::Decode)); + + predicates.push(parse_quote!(#ty: sqlx::decode::Decode<'de, sqlx::Postgres>)); predicates.push(parse_quote!(#ty: sqlx::types::Type)); } + let (impl_generics, _, where_clause) = generics.split_for_impl(); let reads = fields.iter().map(|field| -> Stmt { let id = &field.ident; let ty = &field.ty; + parse_quote!( - let #id = sqlx::postgres::decode_struct_field::<#ty>(&mut buf)?; + let #id = decoder.decode::<#ty>()?; ) }); let names = fields.iter().map(|field| &field.ident); tts.extend(quote!( - impl #impl_generics sqlx::decode::Decode for #ident#ty_generics #where_clause { - fn decode(buf: &[u8]) -> std::result::Result { - if buf.len() < 4 { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); + impl #impl_generics sqlx::decode::Decode<'de, sqlx::Postgres> for #ident #ty_generics #where_clause { + fn decode(value: >::RawValue) -> sqlx::Result { + let mut decoder = sqlx::postgres::types::PgRecordDecoder::new(value)?; + + #(#reads)* + + Ok(#ident { + #(#names),* + }) } - - let column_count = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[..4]).unwrap()) as usize; - if column_count != #column_count { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid column count"))); - } - let mut buf = &buf[4..]; - - #(#reads)* - - if !buf.is_empty() { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new(format!("Too much data sent ({} bytes left)", buf.len())))); - } - - Ok(#ident { - #(#names),* - }) } - } - )) + )); } + Ok(tts) } diff --git a/sqlx-macros/src/derives/encode.rs b/sqlx-macros/src/derives/encode.rs index 56ad9fb7..2de0064d 100644 --- a/sqlx-macros/src/derives/encode.rs +++ b/sqlx-macros/src/derives/encode.rs @@ -1,7 +1,8 @@ use super::attributes::{ check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, - check_weak_enum_attributes, parse_attributes, + check_weak_enum_attributes, parse_container_attributes, parse_child_attributes, }; +use super::rename_all; use quote::quote; use syn::punctuated::Punctuated; use syn::token::Comma; @@ -11,7 +12,7 @@ use syn::{ }; pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result { - let args = parse_attributes(&input.attrs)?; + let args = parse_container_attributes(&input.attrs)?; match &input.data { Data::Struct(DataStruct { @@ -87,18 +88,21 @@ fn expand_derive_encode_weak_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { - let repr = check_weak_enum_attributes(input, &variants)?; + let attr = check_weak_enum_attributes(input, &variants)?; + let repr = attr.repr.unwrap(); let ident = &input.ident; Ok(quote!( impl sqlx::encode::Encode for #ident where #repr: sqlx::encode::Encode { - fn encode(&self, buf: &mut std::vec::Vec) { + fn encode(&self, buf: &mut DB::RawBuffer) { sqlx::encode::Encode::encode(&(*self as #repr), buf) } - fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { + + fn encode_nullable(&self, buf: &mut DB::RawBuffer) -> sqlx::encode::IsNull { sqlx::encode::Encode::encode_nullable(&(*self as #repr), buf) } + fn size_hint(&self) -> usize { sqlx::encode::Encode::size_hint(&(*self as #repr)) } @@ -110,16 +114,21 @@ fn expand_derive_encode_strong_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { - check_strong_enum_attributes(input, &variants)?; + let cattr = check_strong_enum_attributes(input, &variants)?; let ident = &input.ident; let mut value_arms = Vec::new(); for v in variants { let id = &v.ident; - let attributes = parse_attributes(&v.attrs)?; + let attributes = parse_child_attributes(&v.attrs)?; + if let Some(rename) = attributes.rename { value_arms.push(quote!(#ident :: #id => #rename,)); + } else if let Some(pattern) = cattr.rename_all { + let name = rename_all(&*id.to_string(), pattern); + + value_arms.push(quote!(#ident :: #id => #name,)); } else { let name = id.to_string(); value_arms.push(quote!(#ident :: #id => #name,)); @@ -128,12 +137,13 @@ fn expand_derive_encode_strong_enum( Ok(quote!( impl sqlx::encode::Encode for #ident where str: sqlx::encode::Encode { - fn encode(&self, buf: &mut std::vec::Vec) { + fn encode(&self, buf: &mut DB::RawBuffer) { let val = match self { #(#value_arms)* }; >::encode(val, buf) } + fn size_hint(&self) -> usize { let val = match self { #(#value_arms)* @@ -154,7 +164,6 @@ fn expand_derive_encode_struct( if cfg!(feature = "postgres") { let ident = &input.ident; - let column_count = fields.len(); // extract type generics @@ -164,23 +173,29 @@ fn expand_derive_encode_struct( // add db type for impl generics & where clause let mut generics = generics.clone(); let predicates = &mut generics.make_where_clause().predicates; + for field in fields { let ty = &field.ty; + predicates.push(parse_quote!(#ty: sqlx::encode::Encode)); - predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); + predicates.push(parse_quote!(#ty: sqlx::types::Type)); } + let (impl_generics, _, where_clause) = generics.split_for_impl(); let writes = fields.iter().map(|field| -> Stmt { let id = &field.ident; + parse_quote!( - sqlx::postgres::encode_struct_field(buf, &self. #id); + // sqlx::postgres::encode_struct_field(buf, &self. #id); + encoder.encode(&self. #id); ) }); let sizes = fields.iter().map(|field| -> Expr { let id = &field.ident; let ty = &field.ty; + parse_quote!( <#ty as sqlx::encode::Encode>::size_hint(&self. #id) ) @@ -189,13 +204,16 @@ fn expand_derive_encode_struct( tts.extend(quote!( impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { fn encode(&self, buf: &mut std::vec::Vec) { - buf.extend(&(#column_count as u32).to_be_bytes()); + let mut encoder = sqlx::postgres::types::PgRecordEncoder::new(buf); + #(#writes)* + + encoder.finish(); } + fn size_hint(&self) -> usize { - 4 // oid - + #column_count * (4 + 4) // oid (int) and length (int) for each column - + #(#sizes)+* // sum of the size hints for each column + #column_count * (4 + 4) // oid (int) and length (int) for each column + + #(#sizes)+* // sum of the size hints for each column } } )); diff --git a/sqlx-macros/src/derives/mod.rs b/sqlx-macros/src/derives/mod.rs index 4e36533d..888b737e 100644 --- a/sqlx-macros/src/derives/mod.rs +++ b/sqlx-macros/src/derives/mod.rs @@ -7,6 +7,7 @@ pub(crate) use decode::expand_derive_decode; pub(crate) use encode::expand_derive_encode; pub(crate) use r#type::expand_derive_type; +use self::attributes::RenameAll; use std::iter::FromIterator; use syn::DeriveInput; @@ -23,3 +24,11 @@ pub(crate) fn expand_derive_type_encode_decode( Ok(combined) } + +pub(crate) fn rename_all(s: &str, pattern: RenameAll) -> String { + match pattern { + RenameAll::LowerCase => { + s.to_lowercase() + } + } +} diff --git a/sqlx-macros/src/derives/type.rs b/sqlx-macros/src/derives/type.rs index 0c86cb8a..9fbec169 100644 --- a/sqlx-macros/src/derives/type.rs +++ b/sqlx-macros/src/derives/type.rs @@ -1,6 +1,6 @@ use super::attributes::{ check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, - check_weak_enum_attributes, parse_attributes, + check_weak_enum_attributes, parse_container_attributes, }; use quote::quote; use syn::punctuated::Punctuated; @@ -11,7 +11,7 @@ use syn::{ }; pub fn expand_derive_type(input: &DeriveInput) -> syn::Result { - let attrs = parse_attributes(&input.attrs)?; + let attrs = parse_container_attributes(&input.attrs)?; match &input.data { Data::Struct(DataStruct { fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), @@ -65,64 +65,36 @@ fn expand_derive_has_sql_type_transparent( .make_where_clause() .predicates .push(parse_quote!(#ty: sqlx::types::Type)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); - let mut tts = proc_macro2::TokenStream::new(); - - // if cfg!(feature = "mysql") { - tts.extend(quote!( + Ok(quote!( impl #impl_generics sqlx::types::Type< DB > for #ident #ty_generics #where_clause { fn type_info() -> DB::TypeInfo { <#ty as sqlx::Type>::type_info() } } - )); - - // } - - // if cfg!(feature = "postgres") { - // tts.extend(quote!( - // impl #impl_generics sqlx::types::HasSqlType< sqlx::Postgres > #ident #ty_generics #where_clause { - // fn type_info() -> Self::TypeInfo { - // >::type_info() - // } - // } - // )); - // } - - Ok(tts) + )) } fn expand_derive_has_sql_type_weak_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { - let repr = check_weak_enum_attributes(input, variants)?; - + let attr = check_weak_enum_attributes(input, variants)?; + let repr = attr.repr.unwrap(); let ident = &input.ident; - let mut tts = proc_macro2::TokenStream::new(); - if cfg!(feature = "mysql") { - tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::MySql where Self: sqlx::types::HasSqlType< #repr > { - fn type_info() -> Self::TypeInfo { - >::type_info() - } + Ok(quote!( + impl sqlx::Type for #ident + where + #repr: sqlx::Type, + { + fn type_info() -> DB::TypeInfo { + <#repr as sqlx::Type>::type_info() } - )); - } - - if cfg!(feature = "postgres") { - tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres where Self: sqlx::types::HasSqlType< #repr > { - fn type_info() -> Self::TypeInfo { - >::type_info() - } - } - )); - } - - Ok(tts) + } + )) } fn expand_derive_has_sql_type_strong_enum( @@ -136,9 +108,11 @@ fn expand_derive_has_sql_type_strong_enum( if cfg!(feature = "mysql") { tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::MySql { - fn type_info() -> Self::TypeInfo { - sqlx::mysql::MySqlTypeInfo::r#enum() + impl sqlx::Type< sqlx::MySql > for #ident { + fn type_info() -> sqlx::mysql::MySqlTypeInfo { + // This is really fine, MySQL is loosely typed and + // we don't nede to be specific here + >::type_info() } } )); @@ -147,8 +121,8 @@ fn expand_derive_has_sql_type_strong_enum( if cfg!(feature = "postgres") { let oid = attributes.postgres_oid.unwrap(); tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres { - fn type_info() -> Self::TypeInfo { + impl sqlx::Type< sqlx::Postgres > for #ident { + fn type_info() -> sqlx::postgres::PgTypeInfo { sqlx::postgres::PgTypeInfo::with_oid(#oid) } } @@ -170,8 +144,8 @@ fn expand_derive_has_sql_type_struct( if cfg!(feature = "postgres") { let oid = attributes.postgres_oid.unwrap(); tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres { - fn type_info() -> Self::TypeInfo { + impl sqlx::types::Type< sqlx::Postgres > for #ident { + fn type_info() -> sqlx::postgres::PgTypeInfo { sqlx::postgres::PgTypeInfo::with_oid(#oid) } } diff --git a/tests/derives.rs b/tests/derives.rs deleted file mode 100644 index 5b5bd289..00000000 --- a/tests/derives.rs +++ /dev/null @@ -1,318 +0,0 @@ -use sqlx::decode::Decode; -use sqlx::encode::Encode; -use sqlx::types::TypeInfo; -use sqlx::Type; -use std::fmt::Debug; - -#[derive(PartialEq, Debug, Type)] -#[sqlx(transparent)] -struct Transparent(i32); - -// #[derive(PartialEq, Debug, Clone, Copy, Encode, Decode, HasSqlType)] -// #[repr(i32)] -// #[allow(dead_code)] -// enum Weak { -// One, -// Two, -// Three, -// } -// -// #[derive(PartialEq, Debug, Encode, Decode, HasSqlType)] -// #[sqlx(postgres(oid = 10101010))] -// #[allow(dead_code)] -// enum Strong { -// One, -// Two, -// #[sqlx(rename = "four")] -// Three, -// } -// -// #[derive(PartialEq, Debug, Encode, Decode, HasSqlType)] -// #[sqlx(postgres(oid = 20202020))] -// #[allow(dead_code)] -// struct Struct { -// field1: String, -// field2: i64, -// field3: bool, -// } - -#[test] -#[cfg(feature = "mysql")] -fn encode_transparent_mysql() { - encode_transparent::(); -} - -#[test] -#[cfg(feature = "postgres")] -fn encode_transparent_postgres() { - encode_transparent::(); -} - -#[allow(dead_code)] -fn encode_transparent>>() -where - Transparent: Encode, - i32: Encode, -{ - let example = Transparent(0x1122_3344); - - let mut encoded = Vec::new(); - let mut encoded_orig = Vec::new(); - - Encode::::encode(&example, &mut encoded); - Encode::::encode(&example.0, &mut encoded_orig); - - assert_eq!(encoded, encoded_orig); -} -// -// #[test] -// #[cfg(feature = "mysql")] -// fn encode_weak_enum_mysql() { -// encode_weak_enum::(); -// } -// -// #[test] -// #[cfg(feature = "postgres")] -// fn encode_weak_enum_postgres() { -// encode_weak_enum::(); -// } -// -// #[allow(dead_code)] -// fn encode_weak_enum>>() -// where -// Weak: Encode, -// i32: Encode, -// { -// for example in [Weak::One, Weak::Two, Weak::Three].iter() { -// let mut encoded = Vec::new(); -// let mut encoded_orig = Vec::new(); -// -// Encode::::encode(example, &mut encoded); -// Encode::::encode(&(*example as i32), &mut encoded_orig); -// -// assert_eq!(encoded, encoded_orig); -// } -// } -// -// #[test] -// #[cfg(feature = "mysql")] -// fn encode_strong_enum_mysql() { -// encode_strong_enum::(); -// } -// -// #[test] -// #[cfg(feature = "postgres")] -// fn encode_strong_enum_postgres() { -// encode_strong_enum::(); -// } -// -// #[allow(dead_code)] -// fn encode_strong_enum>>() -// where -// Strong: Encode, -// str: Encode, -// { -// for (example, name) in [ -// (Strong::One, "One"), -// (Strong::Two, "Two"), -// (Strong::Three, "four"), -// ] -// .iter() -// { -// let mut encoded = Vec::new(); -// let mut encoded_orig = Vec::new(); -// -// Encode::::encode(example, &mut encoded); -// Encode::::encode(*name, &mut encoded_orig); -// -// assert_eq!(encoded, encoded_orig); -// } -// } -// -// #[test] -// #[cfg(feature = "postgres")] -// fn encode_struct_postgres() { -// let field1 = "Foo".to_string(); -// let field2 = 3; -// let field3 = false; -// -// let example = Struct { -// field1: field1.clone(), -// field2, -// field3, -// }; -// -// let mut encoded = Vec::new(); -// Encode::::encode(&example, &mut encoded); -// -// let string_oid = >::type_info().oid(); -// let i64_oid = >::type_info().oid(); -// let bool_oid = >::type_info().oid(); -// -// // 3 columns -// assert_eq!(&[0, 0, 0, 3], &encoded[..4]); -// let encoded = &encoded[4..]; -// -// // check field1 (string) -// assert_eq!(&string_oid.to_be_bytes(), &encoded[0..4]); -// assert_eq!(&(field1.len() as u32).to_be_bytes(), &encoded[4..8]); -// assert_eq!(field1.as_bytes(), &encoded[8..8 + field1.len()]); -// let encoded = &encoded[8 + field1.len()..]; -// -// // check field2 (i64) -// assert_eq!(&i64_oid.to_be_bytes(), &encoded[0..4]); -// assert_eq!(&8u32.to_be_bytes(), &encoded[4..8]); -// assert_eq!(field2.to_be_bytes(), &encoded[8..16]); -// let encoded = &encoded[16..]; -// -// // check field3 (bool) -// assert_eq!(&bool_oid.to_be_bytes(), &encoded[0..4]); -// assert_eq!(&1u32.to_be_bytes(), &encoded[4..8]); -// assert_eq!(field3, encoded[8] != 0); -// let encoded = &encoded[9..]; -// -// assert!(encoded.is_empty()); -// -// let string_size = >::size_hint(&field1); -// let i64_size = >::size_hint(&field2); -// let bool_size = >::size_hint(&field3); -// -// assert_eq!( -// 4 + 3 * (4 + 4) + string_size + i64_size + bool_size, -// example.size_hint() -// ); -// } - -#[test] -#[cfg(feature = "mysql")] -fn decode_transparent_mysql() { - decode_with_db::(Transparent(0x1122_3344)); -} - -#[test] -#[cfg(feature = "postgres")] -fn decode_transparent_postgres() { - decode_with_db::(Transparent(0x1122_3344)); -} -// -// #[test] -// #[cfg(feature = "mysql")] -// fn decode_weak_enum_mysql() { -// decode_with_db::(Weak::One); -// decode_with_db::(Weak::Two); -// decode_with_db::(Weak::Three); -// } -// -// #[test] -// #[cfg(feature = "postgres")] -// fn decode_weak_enum_postgres() { -// decode_with_db::(Weak::One); -// decode_with_db::(Weak::Two); -// decode_with_db::(Weak::Three); -// } -// -// #[test] -// #[cfg(feature = "mysql")] -// fn decode_strong_enum_mysql() { -// decode_with_db::(Strong::One); -// decode_with_db::(Strong::Two); -// decode_with_db::(Strong::Three); -// } -// -// #[test] -// #[cfg(feature = "postgres")] -// fn decode_strong_enum_postgres() { -// decode_with_db::(Strong::One); -// decode_with_db::(Strong::Two); -// decode_with_db::(Strong::Three); -// } -// -// #[test] -// #[cfg(feature = "postgres")] -// fn decode_struct_postgres() { -// decode_with_db::(Struct { -// field1: "Foo".to_string(), -// field2: 3, -// field3: true, -// }); -// } -// -#[allow(dead_code)] -fn decode_with_db< - DB: sqlx::Database>, - V: for<'de> Decode<'de, DB> + Encode + PartialEq + Debug, ->( - example: V, -) { - let mut encoded = Vec::new(); - Encode::::encode(&example, &mut encoded); - - // let decoded = V::decode(&encoded).unwrap(); - // assert_eq!(example, decoded); -} - -#[test] -#[cfg(feature = "mysql")] -fn type_transparent_mysql() { - type_transparent::(); -} - -#[test] -#[cfg(feature = "postgres")] -fn type_transparent_postgres() { - type_transparent::(); -} - -#[allow(dead_code)] -fn type_transparent>>() -where - Transparent: Type, - i32: Type, -{ - let info: DB::TypeInfo = >::type_info(); - let info_orig: DB::TypeInfo = >::type_info(); - assert!(info.compatible(&info_orig)); -} -// -// #[test] -// #[cfg(feature = "mysql")] -// fn type_weak_enum_mysql() { -// type_weak_enum::(); -// } -// -// #[test] -// #[cfg(feature = "postgres")] -// fn type_weak_enum_postgres() { -// type_weak_enum::(); -// } -// -// #[allow(dead_code)] -// fn type_weak_enum>>() -// where -// DB: HasSqlType + HasSqlType, -// { -// let info: DB::TypeInfo = >::type_info(); -// let info_orig: DB::TypeInfo = >::type_info(); -// assert!(info.compatible(&info_orig)); -// } -// -// #[test] -// #[cfg(feature = "mysql")] -// fn type_strong_enum_mysql() { -// let info: sqlx::mysql::MySqlTypeInfo = >::type_info(); -// assert!(info.compatible(&sqlx::mysql::MySqlTypeInfo::r#enum())) -// } -// -// #[test] -// #[cfg(feature = "postgres")] -// fn type_strong_enum_postgres() { -// let info: sqlx::postgres::PgTypeInfo = >::type_info(); -// assert!(info.compatible(&sqlx::postgres::PgTypeInfo::with_oid(10101010))) -// } -// -// #[test] -// #[cfg(feature = "postgres")] -// fn type_struct_postgres() { -// let info: sqlx::postgres::PgTypeInfo = >::type_info(); -// assert!(info.compatible(&sqlx::postgres::PgTypeInfo::with_oid(20202020))) -// } diff --git a/tests/mysql-derives.rs b/tests/mysql-derives.rs new file mode 100644 index 00000000..c92d06f9 --- /dev/null +++ b/tests/mysql-derives.rs @@ -0,0 +1,47 @@ +use sqlx_test::test_type; +use std::fmt::Debug; +use sqlx::MySql; + +// Transparent types are rust-side wrappers over DB types +#[derive(PartialEq, Debug, sqlx::Type)] +#[sqlx(transparent)] +struct Transparent(i32); + +// "Weak" enums map to an integer type indicated by #[repr] +#[derive(PartialEq, Copy, Clone, Debug, sqlx::Type)] +#[repr(i32)] +enum Weak { + One = 0, + Two = 2, + Three = 4, +} + +// "Strong" enums can map to TEXT or a custom enum +#[derive(PartialEq, Debug, sqlx::Type)] +#[sqlx(rename_all = "lowercase")] +enum Color { + Red, + Green, + Blue, +} + +test_type!(transparent( + MySql, + Transparent, + "0" == Transparent(0), + "23523" == Transparent(23523) +)); + +test_type!(weak_enum( + MySql, + Weak, + "0" == Weak::One, + "2" == Weak::Two, + "4" == Weak::Three +)); + +test_type!(strong_color_enum( + MySql, + Color, + "'green'" == Color::Green +)); diff --git a/tests/postgres-derives.rs b/tests/postgres-derives.rs new file mode 100644 index 00000000..0a3504d3 --- /dev/null +++ b/tests/postgres-derives.rs @@ -0,0 +1,81 @@ +use sqlx_test::test_type; +use std::fmt::Debug; +use sqlx::Postgres; + +// Transparent types are rust-side wrappers over DB types +#[derive(PartialEq, Debug, sqlx::Type)] +#[sqlx(transparent)] +struct Transparent(i32); + +// "Weak" enums map to an integer type indicated by #[repr] +#[derive(PartialEq, Copy, Clone, Debug, sqlx::Type)] +#[repr(i32)] +enum Weak { + One = 0, + Two = 2, + Three = 4, +} + +// "Strong" enums can map to TEXT (25) or a custom enum type +#[derive(PartialEq, Debug, sqlx::Type)] +#[sqlx(postgres(oid = 25))] +#[sqlx(rename_all = "lowercase")] +enum Strong { + One, + Two, + + #[sqlx(rename = "four")] + Three, +} + +// Records must map to a custom type +// Note that all types are types in Postgres +#[derive(PartialEq, Debug, sqlx::Type)] +#[sqlx(postgres(oid = 12184))] +struct PgConfig { + name: String, + setting: Option, +} + +test_type!(transparent( + Postgres, + Transparent, + "0" == Transparent(0), + "23523" == Transparent(23523) +)); + +test_type!(weak_enum( + Postgres, + Weak, + "0::int4" == Weak::One, + "2::int4" == Weak::Two, + "4::int4" == Weak::Three +)); + +test_type!(strong_enum( + Postgres, + Strong, + "'one'::text" == Strong::One, + "'two'::text" == Strong::Two, + "'four'::text" == Strong::Three +)); + +test_type!(record_pg_config( + Postgres, + PgConfig, + // (CC,gcc) + "(SELECT ROW('CC', 'gcc')::pg_config)" == PgConfig { + name: "CC".to_owned(), + setting: Some("gcc".to_owned()), + }, + // (CC,) + "(SELECT '(\"CC\",)'::pg_config)" == PgConfig { + name: "CC".to_owned(), + setting: None, + }, + // (CC,"") + "(SELECT '(\"CC\",\"\")'::pg_config)" == PgConfig { + name: "CC".to_owned(), + setting: Some("".to_owned()), + } +)); diff --git a/tests/postgres.rs b/tests/postgres.rs index 2eb1c922..5d2b7d59 100644 --- a/tests/postgres.rs +++ b/tests/postgres.rs @@ -1,12 +1,13 @@ use futures::TryStreamExt; +use sqlx_test::new; use sqlx::postgres::{PgPool, PgQueryAs, PgRow}; -use sqlx::{postgres::PgConnection, Connect, Connection, Executor, Row}; +use sqlx::{Postgres, Connection, Executor, Row}; use std::time::Duration; #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn it_connects() -> anyhow::Result<()> { - let mut conn = connect().await?; + let mut conn = new::().await?; let value = sqlx::query("select 1 + 1") .try_map(|row: PgRow| row.try_get::(0)) @@ -21,7 +22,7 @@ async fn it_connects() -> anyhow::Result<()> { #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn it_executes() -> anyhow::Result<()> { - let mut conn = connect().await?; + let mut conn = new::().await?; let _ = conn .execute( @@ -55,7 +56,7 @@ CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY); #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn it_can_return_interleaved_nulls_issue_104() -> anyhow::Result<()> { - let mut conn = connect().await?; + let mut conn = new::().await?; let tuple = sqlx::query("SELECT NULL::INT, 10::INT, NULL, 20::INT, NULL, 40::INT, NULL, 80::INT") @@ -89,7 +90,7 @@ async fn it_can_return_interleaved_nulls_issue_104() -> anyhow::Result<()> { #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn it_can_work_with_transactions() -> anyhow::Result<()> { - let mut conn = connect().await?; + let mut conn = new::().await?; conn.execute("CREATE TABLE IF NOT EXISTS _sqlx_users_1922 (id INTEGER PRIMARY KEY)") .await?; @@ -141,7 +142,7 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> { .await?; } - conn = connect().await?; + conn = new::().await?; let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_1922") .fetch_one(&mut conn) @@ -210,7 +211,7 @@ async fn pool_smoke_test() -> anyhow::Result<()> { #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn test_describe() -> anyhow::Result<()> { - let mut conn = connect().await?; + let mut conn = new::().await?; let _ = conn .execute( @@ -239,10 +240,3 @@ async fn test_describe() -> anyhow::Result<()> { Ok(()) } - -async fn connect() -> anyhow::Result { - let _ = dotenv::dotenv(); - let _ = env_logger::try_init(); - - Ok(PgConnection::connect(dotenv::var("DATABASE_URL")?).await?) -} From f7e08ea4d83489472c3bee0743c11228081aff32 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Tue, 17 Mar 2020 19:33:20 -0700 Subject: [PATCH 23/25] Remove mention of old derives test --- Cargo.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6e3443c6..a27399e9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -115,7 +115,3 @@ required-features = [ "postgres", "macros" ] [[test]] name = "mysql-types" required-features = [ "mysql" ] - -[[test]] -name = "derives" -required-features = [ "macros" ] From ff722d0e627d18a1415ef95833dee9c770060a5e Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Tue, 17 Mar 2020 19:35:01 -0700 Subject: [PATCH 24/25] Run rustfmt --- sqlx-macros/src/derives/attributes.rs | 10 +++------ sqlx-macros/src/derives/decode.rs | 3 +-- sqlx-macros/src/derives/encode.rs | 2 +- sqlx-macros/src/derives/mod.rs | 4 +--- tests/mysql-derives.rs | 8 ++------ tests/postgres-derives.rs | 29 +++++++++++++++------------ tests/postgres.rs | 4 ++-- 7 files changed, 26 insertions(+), 34 deletions(-) diff --git a/sqlx-macros/src/derives/attributes.rs b/sqlx-macros/src/derives/attributes.rs index a75c82a8..3c3d1f66 100644 --- a/sqlx-macros/src/derives/attributes.rs +++ b/sqlx-macros/src/derives/attributes.rs @@ -73,7 +73,7 @@ pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result { for value in list.nested.iter() { @@ -148,9 +148,7 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result syn::Result<()> { @@ -188,9 +186,7 @@ pub fn check_transparent_attributes(input: &DeriveInput, field: &Field) -> syn:: Ok(()) } -pub fn check_enum_attributes<'a>( - input: &'a DeriveInput, -) -> syn::Result { +pub fn check_enum_attributes<'a>(input: &'a DeriveInput) -> syn::Result { let attributes = parse_container_attributes(&input.attrs)?; assert_attribute!( diff --git a/sqlx-macros/src/derives/decode.rs b/sqlx-macros/src/derives/decode.rs index 9dae6fa3..5b599019 100644 --- a/sqlx-macros/src/derives/decode.rs +++ b/sqlx-macros/src/derives/decode.rs @@ -1,7 +1,6 @@ use super::attributes::{ check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, - check_weak_enum_attributes, parse_container_attributes, - parse_child_attributes, + check_weak_enum_attributes, parse_child_attributes, parse_container_attributes, }; use super::rename_all; use quote::quote; diff --git a/sqlx-macros/src/derives/encode.rs b/sqlx-macros/src/derives/encode.rs index 2de0064d..e6b2c237 100644 --- a/sqlx-macros/src/derives/encode.rs +++ b/sqlx-macros/src/derives/encode.rs @@ -1,6 +1,6 @@ use super::attributes::{ check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, - check_weak_enum_attributes, parse_container_attributes, parse_child_attributes, + check_weak_enum_attributes, parse_child_attributes, parse_container_attributes, }; use super::rename_all; use quote::quote; diff --git a/sqlx-macros/src/derives/mod.rs b/sqlx-macros/src/derives/mod.rs index 888b737e..2d8208f2 100644 --- a/sqlx-macros/src/derives/mod.rs +++ b/sqlx-macros/src/derives/mod.rs @@ -27,8 +27,6 @@ pub(crate) fn expand_derive_type_encode_decode( pub(crate) fn rename_all(s: &str, pattern: RenameAll) -> String { match pattern { - RenameAll::LowerCase => { - s.to_lowercase() - } + RenameAll::LowerCase => s.to_lowercase(), } } diff --git a/tests/mysql-derives.rs b/tests/mysql-derives.rs index c92d06f9..0e875e8c 100644 --- a/tests/mysql-derives.rs +++ b/tests/mysql-derives.rs @@ -1,6 +1,6 @@ +use sqlx::MySql; use sqlx_test::test_type; use std::fmt::Debug; -use sqlx::MySql; // Transparent types are rust-side wrappers over DB types #[derive(PartialEq, Debug, sqlx::Type)] @@ -40,8 +40,4 @@ test_type!(weak_enum( "4" == Weak::Three )); -test_type!(strong_color_enum( - MySql, - Color, - "'green'" == Color::Green -)); +test_type!(strong_color_enum(MySql, Color, "'green'" == Color::Green)); diff --git a/tests/postgres-derives.rs b/tests/postgres-derives.rs index 0a3504d3..7c81ecf5 100644 --- a/tests/postgres-derives.rs +++ b/tests/postgres-derives.rs @@ -1,6 +1,6 @@ +use sqlx::Postgres; use sqlx_test::test_type; use std::fmt::Debug; -use sqlx::Postgres; // Transparent types are rust-side wrappers over DB types #[derive(PartialEq, Debug, sqlx::Type)] @@ -64,18 +64,21 @@ test_type!(record_pg_config( Postgres, PgConfig, // (CC,gcc) - "(SELECT ROW('CC', 'gcc')::pg_config)" == PgConfig { - name: "CC".to_owned(), - setting: Some("gcc".to_owned()), - }, + "(SELECT ROW('CC', 'gcc')::pg_config)" + == PgConfig { + name: "CC".to_owned(), + setting: Some("gcc".to_owned()), + }, // (CC,) - "(SELECT '(\"CC\",)'::pg_config)" == PgConfig { - name: "CC".to_owned(), - setting: None, - }, + "(SELECT '(\"CC\",)'::pg_config)" + == PgConfig { + name: "CC".to_owned(), + setting: None, + }, // (CC,"") - "(SELECT '(\"CC\",\"\")'::pg_config)" == PgConfig { - name: "CC".to_owned(), - setting: Some("".to_owned()), - } + "(SELECT '(\"CC\",\"\")'::pg_config)" + == PgConfig { + name: "CC".to_owned(), + setting: Some("".to_owned()), + } )); diff --git a/tests/postgres.rs b/tests/postgres.rs index 5d2b7d59..9f11d87f 100644 --- a/tests/postgres.rs +++ b/tests/postgres.rs @@ -1,7 +1,7 @@ use futures::TryStreamExt; -use sqlx_test::new; use sqlx::postgres::{PgPool, PgQueryAs, PgRow}; -use sqlx::{Postgres, Connection, Executor, Row}; +use sqlx::{Connection, Executor, Postgres, Row}; +use sqlx_test::new; use std::time::Duration; #[cfg_attr(feature = "runtime-async-std", async_std::test)] From fb5db48c529e6861f1ade712fce7d847d5b31420 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Tue, 17 Mar 2020 19:42:23 -0700 Subject: [PATCH 25/25] Don't test custom records until we have some kind of fixtures --- tests/postgres-derives.rs | 58 ++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/tests/postgres-derives.rs b/tests/postgres-derives.rs index 7c81ecf5..11689b1b 100644 --- a/tests/postgres-derives.rs +++ b/tests/postgres-derives.rs @@ -28,14 +28,15 @@ enum Strong { Three, } +// TODO: Figure out a good solution for custom type testing // Records must map to a custom type // Note that all types are types in Postgres -#[derive(PartialEq, Debug, sqlx::Type)] -#[sqlx(postgres(oid = 12184))] -struct PgConfig { - name: String, - setting: Option, -} +// #[derive(PartialEq, Debug, sqlx::Type)] +// #[sqlx(postgres(oid = 12184))] +// struct PgConfig { +// name: String, +// setting: Option, +// } test_type!(transparent( Postgres, @@ -60,25 +61,26 @@ test_type!(strong_enum( "'four'::text" == Strong::Three )); -test_type!(record_pg_config( - Postgres, - PgConfig, - // (CC,gcc) - "(SELECT ROW('CC', 'gcc')::pg_config)" - == PgConfig { - name: "CC".to_owned(), - setting: Some("gcc".to_owned()), - }, - // (CC,) - "(SELECT '(\"CC\",)'::pg_config)" - == PgConfig { - name: "CC".to_owned(), - setting: None, - }, - // (CC,"") - "(SELECT '(\"CC\",\"\")'::pg_config)" - == PgConfig { - name: "CC".to_owned(), - setting: Some("".to_owned()), - } -)); +// TODO: Figure out a good solution for custom type testing +// test_type!(record_pg_config( +// Postgres, +// PgConfig, +// // (CC,gcc) +// "(SELECT ROW('CC', 'gcc')::pg_config)" +// == PgConfig { +// name: "CC".to_owned(), +// setting: Some("gcc".to_owned()), +// }, +// // (CC,) +// "(SELECT '(\"CC\",)'::pg_config)" +// == PgConfig { +// name: "CC".to_owned(), +// setting: None, +// }, +// // (CC,"") +// "(SELECT '(\"CC\",\"\")'::pg_config)" +// == PgConfig { +// name: "CC".to_owned(), +// setting: Some("".to_owned()), +// } +// ));