diff --git a/sqlx-core/src/mysql/mod.rs b/sqlx-core/src/mysql/mod.rs index 7efddc38..1cd9cd2a 100644 --- a/sqlx-core/src/mysql/mod.rs +++ b/sqlx-core/src/mysql/mod.rs @@ -29,4 +29,4 @@ pub type MySqlPool = crate::pool::Pool; make_query_as!(MySqlQueryAs, MySql, MySqlRow); impl_map_row_for_row!(MySql, MySqlRow); impl_column_index_for_row!(MySql); -impl_from_row_for_tuples!(MySql, MySqlRow); +impl_from_row_for_tuples!(MySql, MySqlRow); \ No newline at end of file diff --git a/sqlx-core/src/mysql/protocol/type.rs b/sqlx-core/src/mysql/protocol/type.rs index 38ecb453..792ccb57 100644 --- a/sqlx-core/src/mysql/protocol/type.rs +++ b/sqlx-core/src/mysql/protocol/type.rs @@ -39,6 +39,9 @@ type_id_consts! { pub const VAR_CHAR: TypeId = TypeId(253); // or VAR_BINARY pub const TEXT: TypeId = TypeId(252); // or BLOB + // Enum + pub const ENUM: TypeId = TypeId(247); + // More Bytes pub const TINY_BLOB: TypeId = TypeId(249); pub const MEDIUM_BLOB: TypeId = TypeId(250); diff --git a/sqlx-core/src/mysql/types/mod.rs b/sqlx-core/src/mysql/types/mod.rs index 3c1311b0..343d39aa 100644 --- a/sqlx-core/src/mysql/types/mod.rs +++ b/sqlx-core/src/mysql/types/mod.rs @@ -64,6 +64,11 @@ impl MySqlTypeInfo { _ => None, } } + + #[doc(hidden)] + pub fn r#enum() -> Self { + Self::new(TypeId::ENUM) + } } impl Display for MySqlTypeInfo { @@ -85,6 +90,7 @@ impl TypeInfo for MySqlTypeInfo { | TypeId::TINY_BLOB | TypeId::MEDIUM_BLOB | TypeId::LONG_BLOB + | TypeId::ENUM if (self.is_binary == other.is_binary) && match other.id { TypeId::VAR_CHAR @@ -92,7 +98,8 @@ impl TypeInfo for MySqlTypeInfo { | TypeId::CHAR | TypeId::TINY_BLOB | TypeId::MEDIUM_BLOB - | TypeId::LONG_BLOB => true, + | TypeId::LONG_BLOB + | TypeId::ENUM => true, _ => false, } => diff --git a/sqlx-core/src/postgres/types/mod.rs b/sqlx-core/src/postgres/types/mod.rs index f017ac7f..ab676e9a 100644 --- a/sqlx-core/src/postgres/types/mod.rs +++ b/sqlx-core/src/postgres/types/mod.rs @@ -57,6 +57,10 @@ impl PgTypeInfo { _ => None, } } + + pub fn oid(&self) -> u32 { + self.id.0 + } } impl Display for PgTypeInfo { diff --git a/sqlx-macros/src/derives.rs b/sqlx-macros/src/derives.rs index b4e73a07..fdba0742 100644 --- a/sqlx-macros/src/derives.rs +++ b/sqlx-macros/src/derives.rs @@ -1,108 +1,847 @@ +use proc_macro2::Ident; use quote::quote; -use syn::{parse_quote, Data, DataStruct, DeriveInput, Fields, FieldsUnnamed}; +use std::iter::FromIterator; +use syn::punctuated::Punctuated; +use syn::token::Comma; +use syn::{ + parse_quote, Arm, Attribute, Block, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, + Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Variant,Stmt, +}; -pub(crate) fn expand_derive_encode(input: DeriveInput) -> syn::Result { - match &input.data { - Data::Struct(DataStruct { - fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), - .. - }) if unnamed.len() == 1 => { - let ident = &input.ident; - let ty = &unnamed.first().unwrap().ty; +macro_rules! assert_attribute { + ($e:expr, $err:expr, $input:expr) => { + if !$e { + return Err(syn::Error::new_spanned($input, $err)); + } + }; +} - // extract type generics - let generics = &input.generics; - let (_, ty_generics, _) = generics.split_for_impl(); +struct SqlxAttributes { + transparent: bool, + postgres_oid: Option, + repr: Option, + rename: Option, +} - // add db type for impl generics & where clause - let mut generics = generics.clone(); - generics.params.insert(0, parse_quote!(DB: sqlx::Database)); - generics - .make_where_clause() - .predicates - .push(parse_quote!(#ty: sqlx::encode::Encode)); - let (impl_generics, _, where_clause) = generics.split_for_impl(); +fn parse_attributes(input: &[Attribute]) -> syn::Result { + let mut transparent = None; + let mut postgres_oid = None; + let mut repr = None; + let mut rename = None; - Ok(quote!( - impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { - fn encode(&self, buf: &mut ::RawBuffer) { - sqlx::encode::Encode::encode(&self.0, buf) - } - fn encode_nullable(&self, buf: &mut ::RawBuffer) -> sqlx::encode::IsNull { - sqlx::encode::Encode::encode_nullable(&self.0, buf) - } - fn size_hint(&self) -> usize { - sqlx::encode::Encode::size_hint(&self.0) + macro_rules! fail { + ($t:expr, $m:expr) => { + return Err(syn::Error::new_spanned($t, $m)); + }; + } + + macro_rules! try_set { + ($i:ident, $v:expr, $t:expr) => { + match $i { + None => $i = Some($v), + Some(_) => fail!($t, "duplicate attribute"), + } + }; + } + + for attr in input { + let meta = attr + .parse_meta() + .map_err(|e| syn::Error::new_spanned(attr, e))?; + match meta { + Meta::List(list) if list.path.is_ident("sqlx") => { + for value in list.nested.iter() { + match value { + NestedMeta::Meta(meta) => match meta { + Meta::Path(p) if p.is_ident("transparent") => { + try_set!(transparent, true, value) + } + Meta::NameValue(MetaNameValue { + path, + lit: Lit::Str(val), + .. + }) if path.is_ident("rename") => try_set!(rename, val.value(), value), + Meta::List(list) if list.path.is_ident("postgres") => { + for value in list.nested.iter() { + match value { + NestedMeta::Meta(Meta::NameValue(MetaNameValue { + path, + lit: Lit::Int(val), + .. + })) if path.is_ident("oid") => { + try_set!(postgres_oid, val.base10_parse()?, value); + } + u => fail!(u, "unexpected value"), + } + } + } + + u => fail!(u, "unexpected attribute"), + }, + u => fail!(u, "unexpected attribute"), } } - )) + } + Meta::List(list) if list.path.is_ident("repr") => { + if list.nested.len() != 1 { + fail!(&list.nested, "expected one value") + } + match list.nested.first().unwrap() { + NestedMeta::Meta(Meta::Path(p)) if p.get_ident().is_some() => { + try_set!(repr, p.get_ident().unwrap().clone(), list); + } + u => fail!(u, "unexpected value"), + } + } + _ => {} } - _ => Err(syn::Error::new_spanned( - input, - "expected a tuple struct with a single field", - )), } + + Ok(SqlxAttributes { + transparent: transparent.unwrap_or(false), + postgres_oid, + repr, + rename, + }) } -pub(crate) fn expand_derive_decode(input: DeriveInput) -> syn::Result { +fn check_transparent_attributes(input: &DeriveInput, field: &Field) -> syn::Result<()> { + let attributes = parse_attributes(&input.attrs)?; + assert_attribute!( + attributes.transparent, + "expected #[sqlx(transparent)]", + input + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + field + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); + let attributes = parse_attributes(&field.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + field + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + field + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + field + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field); + Ok(()) +} + +fn check_enum_attributes<'a>( + input: &'a DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = parse_attributes(&input.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + input + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + input + ); + + for variant in variants { + let attributes = parse_attributes(&variant.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + variant + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + variant + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", variant); + } + + Ok(attributes) +} + +fn check_weak_enum_attributes( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = check_enum_attributes(input, variants)?; + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!(attributes.repr.is_some(), "expected #[repr(..)]", input); + for variant in variants { + let attributes = parse_attributes(&variant.attrs)?; + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + variant + ); + } + Ok(attributes.repr.unwrap()) +} + +fn check_strong_enum_attributes( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = check_enum_attributes(input, variants)?; + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_some(), + "expected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); + Ok(attributes) +} + +fn check_struct_attributes<'a>( + input: &'a DeriveInput, + fields: &Punctuated, +) -> syn::Result { + let attributes = parse_attributes(&input.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + input + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_some(), + "expected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + input + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); + + for field in fields { + let attributes = parse_attributes(&field.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + field + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + field + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + field + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field); + } + + Ok(attributes) +} + +pub(crate) fn expand_derive_encode(input: &DeriveInput) -> syn::Result { + let args = parse_attributes(&input.attrs)?; + match &input.data { Data::Struct(DataStruct { fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), .. }) if unnamed.len() == 1 => { - let ident = &input.ident; - let ty = &unnamed.first().unwrap().ty; - - // extract type generics - let generics = &input.generics; - let (_, ty_generics, _) = generics.split_for_impl(); - - let mut impls = Vec::new(); - - if cfg!(feature = "postgres") { - let mut generics = generics.clone(); - generics.params.insert(0, parse_quote!('de)); - generics - .make_where_clause() - .predicates - .push(parse_quote!(#ty: sqlx::decode::Decode<'de, sqlx::Postgres>)); - - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - impls.push(quote!( - impl #impl_generics sqlx::decode::Decode<'de, sqlx::Postgres> for #ident #ty_generics #where_clause { - fn decode(value: >::RawValue) -> sqlx::Result { - <#ty as sqlx::decode::Decode<'de, sqlx::Postgres>>::decode(value).map(Self) - } - } - )); - } - - if cfg!(feature = "mysql") { - let mut generics = generics.clone(); - generics.params.insert(0, parse_quote!('de)); - generics - .make_where_clause() - .predicates - .push(parse_quote!(#ty: sqlx::decode::Decode<'de, sqlx::MySql>)); - - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - impls.push(quote!( - impl #impl_generics sqlx::decode::Decode<'de, sqlx::MySql> for #ident #ty_generics #where_clause { - fn decode(value: >::RawValue) -> sqlx::Result { - <#ty as sqlx::decode::Decode<'de, sqlx::MySql>>::decode(value).map(Self) - } - } - )); - } - - // panic!("{}", q) - Ok(quote!(#(#impls)*)) + expand_derive_encode_transparent(&input, unnamed.first().unwrap()) } + Data::Enum(DataEnum { variants, .. }) => match args.repr { + Some(_) => expand_derive_encode_weak_enum(input, variants), + None => expand_derive_encode_strong_enum(input, variants), + }, + Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { named, .. }), + .. + }) => expand_derive_encode_struct(input, named), _ => Err(syn::Error::new_spanned( input, "expected a tuple struct with a single field", )), } } + +fn expand_derive_encode_transparent( + input: &DeriveInput, + field: &Field, +) -> syn::Result { + check_transparent_attributes(input, field)?; + + let ident = &input.ident; + let ty = &field.ty; + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::encode::Encode)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + Ok(quote!( + impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { + fn encode(&self, buf: &mut std::vec::Vec) { + sqlx::encode::Encode::encode(&self.0, buf) + } + fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { + sqlx::encode::Encode::encode_nullable(&self.0, buf) + } + fn size_hint(&self) -> usize { + sqlx::encode::Encode::size_hint(&self.0) + } + } + )) +} + +fn expand_derive_encode_weak_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let repr = check_weak_enum_attributes(input, &variants)?; + + let ident = &input.ident; + + Ok(quote!( + impl sqlx::encode::Encode for #ident where #repr: sqlx::encode::Encode { + fn encode(&self, buf: &mut std::vec::Vec) { + sqlx::encode::Encode::encode(&(*self as #repr), buf) + } + fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { + sqlx::encode::Encode::encode_nullable(&(*self as #repr), buf) + } + fn size_hint(&self) -> usize { + sqlx::encode::Encode::size_hint(&(*self as #repr)) + } + } + )) +} + +fn expand_derive_encode_strong_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + check_strong_enum_attributes(input, &variants)?; + + let ident = &input.ident; + + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + let mut value_arms = Vec::new(); + for v in variants { + let id = &v.ident; + let attributes = parse_attributes(&v.attrs)?; + if let Some(rename) = attributes.rename { + value_arms.push(quote!(#ident :: #id => #rename,)); + } else { + let name = id.to_string(); + value_arms.push(quote!(#ident :: #id => #name,)); + } + } + + tts.extend(quote!( + impl sqlx::encode::Encode for #ident where str: sqlx::encode::Encode { + fn encode(&self, buf: &mut std::vec::Vec) { + let val = match self { + #(#value_arms)* + }; + >::encode(val, buf) + } + fn size_hint(&self) -> usize { + let val = match self { + #(#value_arms)* + }; + >::size_hint(val) + } + } + )); + } + + Ok(tts) +} + +fn expand_derive_encode_struct( + input: &DeriveInput, + fields: &Punctuated, +) -> syn::Result { + check_struct_attributes(input, &fields)?; + + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "postgres") { + let ident = &input.ident; + + let column_count = fields.len(); + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + let predicates = &mut generics.make_where_clause().predicates; + for field in fields { + let ty = &field.ty; + predicates.push(parse_quote!(#ty: sqlx::encode::Encode)); + predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); + } + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + let mut writes: Vec = Vec::new(); + for field in fields { + let id = &field.ident; + let ty = &field.ty; + writes.push(parse_quote!({ + // write oid + let info = >::type_info(); + buf.extend(&info.oid().to_be_bytes()); + + // write zeros for length + buf.extend(&[0; 4]); + + let start = buf.len(); + sqlx::encode::Encode::::encode(&self. #id, buf); + let end = buf.len(); + let size = end - start; + + // replaces zeros with actual length + buf[start-4..start].copy_from_slice(&(size as u32).to_be_bytes()); + })); + } + + let mut sizes: Vec = Vec::new(); + for field in fields { + let id = &field.ident; + let ty = &field.ty; + sizes.push( + parse_quote!(<#ty as sqlx::encode::Encode>::size_hint(&self. #id)), + ); + } + + tts.extend(quote!( + impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { + fn encode(&self, buf: &mut std::vec::Vec) { + buf.extend(&(#column_count as u32).to_be_bytes()); + #(#writes)* + } + fn size_hint(&self) -> usize { + 4 + #column_count * (4 + 4) + #(#sizes)+* + } + } + )); + } + + Ok(tts) +} + +pub(crate) fn expand_derive_decode(input: &DeriveInput) -> syn::Result { + let attrs = parse_attributes(&input.attrs)?; + match &input.data { + Data::Struct(DataStruct { + fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), + .. + }) if unnamed.len() == 1 => { + expand_derive_decode_transparent(input, unnamed.first().unwrap()) + } + Data::Enum(DataEnum { variants, .. }) => match attrs.repr { + Some(_) => expand_derive_decode_weak_enum(input, variants), + None => expand_derive_decode_strong_enum(input, variants), + }, + Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { named, .. }), + .. + }) => expand_derive_decode_struct(input, named), + _ => Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )), + } +} + +fn expand_derive_decode_transparent( + input: &DeriveInput, + field: &Field, +) -> syn::Result { + check_transparent_attributes(input, field)?; + + let ident = &input.ident; + let ty = &field.ty; + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::decode::Decode)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + Ok(quote!( + impl #impl_generics sqlx::decode::Decode for #ident #ty_generics #where_clause { + fn decode(raw: &[u8]) -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode(raw).map(Self) + } + fn decode_null() -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode_null().map(Self) + } + fn decode_nullable(raw: std::option::Option<&[u8]>) -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode_nullable(raw).map(Self) + } + } + )) +} + +fn expand_derive_decode_weak_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let repr = check_weak_enum_attributes(input, &variants)?; + + let ident = &input.ident; + let arms = variants + .iter() + .map(|v| { + let id = &v.ident; + parse_quote!(_ if (#ident :: #id as #repr) == val => Ok(#ident :: #id),) + }) + .collect::>(); + + Ok(quote!( + impl sqlx::decode::Decode for #ident where #repr: sqlx::decode::Decode { + fn decode(raw: &[u8]) -> std::result::Result { + let val = <#repr as sqlx::decode::Decode>::decode(raw)?; + match val { + #(#arms)* + _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) + } + } + } + )) +} + +fn expand_derive_decode_strong_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + check_strong_enum_attributes(input, &variants)?; + + let ident = &input.ident; + + let mut value_arms = Vec::new(); + for v in variants { + let id = &v.ident; + let attributes = parse_attributes(&v.attrs)?; + if let Some(rename) = attributes.rename { + value_arms.push(quote!(#rename => Ok(#ident :: #id),)); + } else { + let name = id.to_string(); + value_arms.push(quote!(#name => Ok(#ident :: #id),)); + } + } + + // TODO: prevent heap allocation + Ok(quote!( + impl sqlx::decode::Decode for #ident where String: sqlx::decode::Decode { + fn decode(buf: &[u8]) -> std::result::Result { + let val = >::decode(buf)?; + match val.as_str() { + #(#value_arms)* + _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) + } + } + } + )) +} + +fn expand_derive_decode_struct( + input: &DeriveInput, + fields: &Punctuated, +) -> syn::Result { + check_struct_attributes(input, fields)?; + + let ident = &input.ident; + + let column_count = fields.len(); + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + let predicates = &mut generics.make_where_clause().predicates; + for field in fields { + let ty = &field.ty; + predicates.push(parse_quote!(#ty: sqlx::decode::Decode)); + predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); + } + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + let mut reads: Vec> = Vec::new(); + let mut names: Vec = Vec::new(); + for field in fields { + let id = &field.ident; + names.push(id.clone().unwrap()); + let ty = &field.ty; + reads.push(parse_quote!( + if buf.len() < 8 { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); + } + + let oid = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[0..4]).unwrap()); + if oid != >::type_info().oid() { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid oid"))); + } + + let len = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[4..8]).unwrap()) as usize; + + if buf.len() < 8 + len { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); + } + + let raw = &buf[8..8+len]; + let #id = <#ty as sqlx::decode::Decode>::decode(raw)?; + + let buf = &buf[8+len..]; + )); + } + let reads = reads.into_iter().flatten(); + + Ok(quote!( + impl #impl_generics sqlx::decode::Decode for #ident#ty_generics #where_clause { + fn decode(buf: &[u8]) -> std::result::Result { + if buf.len() < 4 { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); + } + + let column_count = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[..4]).unwrap()) as usize; + if column_count != #column_count { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid column count"))); + } + let buf = &buf[4..]; + + #(#reads)* + + if !buf.is_empty() { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new(format!("Too much data sent ({} bytes left)", buf.len())))); + } + + Ok(#ident { + #(#names),* + }) + } + } + )) +} + +pub(crate) fn expand_derive_has_sql_type( + input: &DeriveInput, +) -> syn::Result { + let attrs = parse_attributes(&input.attrs)?; + match &input.data { + Data::Struct(DataStruct { + fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), + .. + }) if unnamed.len() == 1 => { + expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap()) + } + Data::Enum(DataEnum { variants, .. }) => match attrs.repr { + Some(_) => expand_derive_has_sql_type_weak_enum(input, variants), + None => expand_derive_has_sql_type_strong_enum(input, variants), + }, + Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { named, .. }), + .. + }) => expand_derive_has_sql_type_struct(input, named), + _ => Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )), + } +} + +fn expand_derive_has_sql_type_transparent( + input: &DeriveInput, + field: &Field, +) -> syn::Result { + check_transparent_attributes(input, field)?; + + let ident = &input.ident; + let ty = &field.ty; + + // extract type generics + let generics = &input.generics; + let (impl_generics, ty_generics, _) = generics.split_for_impl(); + + // add db type for clause + let mut generics = generics.clone(); + generics + .make_where_clause() + .predicates + .push(parse_quote!(Self: sqlx::types::HasSqlType<#ty>)); + let (_, _, where_clause) = generics.split_for_impl(); + + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl #impl_generics sqlx::types::HasSqlType< #ident #ty_generics > for sqlx::MySql #where_clause { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + if cfg!(feature = "postgres") { + tts.extend(quote!( + impl #impl_generics sqlx::types::HasSqlType< #ident #ty_generics > for sqlx::Postgres #where_clause { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + Ok(tts) +} + +fn expand_derive_has_sql_type_weak_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let repr = check_weak_enum_attributes(input, variants)?; + + let ident = &input.ident; + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::MySql where Self: sqlx::types::HasSqlType< #repr > { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + if cfg!(feature = "postgres") { + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres where Self: sqlx::types::HasSqlType< #repr > { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + Ok(tts) +} + +fn expand_derive_has_sql_type_strong_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = check_strong_enum_attributes(input, variants)?; + + let ident = &input.ident; + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::MySql { + fn type_info() -> Self::TypeInfo { + sqlx::mysql::MySqlTypeInfo::r#enum() + } + } + )); + } + + if cfg!(feature = "postgres") { + let oid = attributes.postgres_oid.unwrap(); + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres { + fn type_info() -> Self::TypeInfo { + sqlx::postgres::PgTypeInfo::with_oid(#oid) + } + } + )); + } + + Ok(tts) +} + +fn expand_derive_has_sql_type_struct( + input: &DeriveInput, + fields: &Punctuated, +) -> syn::Result { + let attributes = check_struct_attributes(input, fields)?; + + let ident = &input.ident; + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "postgres") { + let oid = attributes.postgres_oid.unwrap(); + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres { + fn type_info() -> Self::TypeInfo { + sqlx::postgres::PgTypeInfo::with_oid(#oid) + } + } + )); + } + + Ok(tts) +} + +pub(crate) fn expand_derive_type(input: &DeriveInput) -> syn::Result { + let encode_tts = expand_derive_encode(input)?; + let decode_tts = expand_derive_decode(input)?; + let has_sql_type_tts = expand_derive_has_sql_type(input)?; + + let combined = proc_macro2::TokenStream::from_iter( + encode_tts + .into_iter() + .chain(decode_tts) + .chain(has_sql_type_tts), + ); + Ok(combined) +} diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs index 0bed307f..f74c775e 100644 --- a/sqlx-macros/src/lib.rs +++ b/sqlx-macros/src/lib.rs @@ -151,19 +151,37 @@ pub fn query_file_as(input: TokenStream) -> TokenStream { async_macro!(db, input: QueryAsMacroInput => expand_query_file_as(input, db)) } -#[proc_macro_derive(Encode)] +#[proc_macro_derive(Encode, attributes(sqlx))] pub fn derive_encode(tokenstream: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); - match derives::expand_derive_encode(input) { + match derives::expand_derive_encode(&input) { Ok(ts) => ts.into(), Err(e) => e.to_compile_error().into(), } } -#[proc_macro_derive(Decode)] +#[proc_macro_derive(Decode, attributes(sqlx))] pub fn derive_decode(tokenstream: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); - match derives::expand_derive_decode(input) { + match derives::expand_derive_decode(&input) { + Ok(ts) => ts.into(), + Err(e) => e.to_compile_error().into(), + } +} + +#[proc_macro_derive(HasSqlType, attributes(sqlx))] +pub fn derive_has_sql_type(tokenstream: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); + match derives::expand_derive_has_sql_type(&input) { + Ok(ts) => ts.into(), + Err(e) => e.to_compile_error().into(), + } +} + +#[proc_macro_derive(Type, attributes(sqlx))] +pub fn derive_type(tokenstream: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); + match derives::expand_derive_type(&input) { Ok(ts) => ts.into(), Err(e) => e.to_compile_error().into(), } diff --git a/src/lib.rs b/src/lib.rs index d1dc6aaa..3342b09b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -40,6 +40,9 @@ pub use sqlx_core::sqlite::{self, Sqlite, SqliteConnection, SqlitePool}; #[doc(hidden)] pub extern crate sqlx_macros; +#[cfg(feature = "macros")] +pub use sqlx_macros::Type; + #[cfg(feature = "macros")] mod macros; diff --git a/src/types.rs b/src/types.rs new file mode 100644 index 00000000..6eb3cab6 --- /dev/null +++ b/src/types.rs @@ -0,0 +1,6 @@ +//! Traits linking Rust types to SQL types. + +pub use sqlx_core::types::*; + +#[cfg(feature = "macros")] +pub use sqlx_macros::HasSqlType; diff --git a/tests/derives.rs b/tests/derives.rs index d5120031..6ec2cc90 100644 --- a/tests/derives.rs +++ b/tests/derives.rs @@ -1,77 +1,311 @@ use sqlx::decode::Decode; use sqlx::encode::Encode; +use sqlx::types::{HasSqlType, TypeInfo}; +use std::fmt::Debug; -#[derive(PartialEq, Debug, Encode, Decode)] -struct Foo(i32); +#[derive(PartialEq, Debug, Encode, Decode, HasSqlType)] +#[sqlx(transparent)] +struct Transparent(i32); -#[test] -#[cfg(feature = "postgres")] -fn encode_with_postgres() { - use sqlx_core::postgres::Postgres; +#[derive(PartialEq, Debug, Clone, Copy, Encode, Decode, HasSqlType)] +#[repr(i32)] +#[allow(dead_code)] +enum Weak { + One, + Two, + Three, +} - let example = Foo(0x1122_3344); +#[derive(PartialEq, Debug, Encode, Decode, HasSqlType)] +#[sqlx(postgres(oid = 10101010))] +#[allow(dead_code)] +enum Strong { + One, + Two, + #[sqlx(rename = "four")] + Three, +} - let mut encoded = Vec::new(); - let mut encoded_orig = Vec::new(); - - Encode::::encode(&example, &mut encoded); - Encode::::encode(&example.0, &mut encoded_orig); - - assert_eq!(encoded, encoded_orig); +#[derive(PartialEq, Debug, Encode, Decode, HasSqlType)] +#[sqlx(postgres(oid = 20202020))] +#[allow(dead_code)] +struct Struct { + field1: String, + field2: i64, + field3: bool, } #[test] #[cfg(feature = "mysql")] -fn encode_with_mysql() { - use sqlx_core::mysql::MySql; - - let example = Foo(0x1122_3344); - - let mut encoded = Vec::new(); - let mut encoded_orig = Vec::new(); - - Encode::::encode(&example, &mut encoded); - Encode::::encode(&example.0, &mut encoded_orig); - - assert_eq!(encoded, encoded_orig); -} - -#[test] -#[cfg(feature = "mysql")] -fn decode_mysql() { - decode_with_db(); +fn encode_transparent_mysql() { + encode_transparent::(); } #[test] #[cfg(feature = "postgres")] -fn decode_postgres() { - decode_with_db(); +fn encode_transparent_postgres() { + encode_transparent::(); } -#[cfg(feature = "postgres")] -fn decode_with_db() +#[allow(dead_code)] +fn encode_transparent() where - Foo: for<'de> Decode<'de, sqlx::Postgres> + Encode, + Transparent: Encode, + i32: Encode, { - let example = Foo(0x1122_3344); + let example = Transparent(0x1122_3344); + + let mut encoded = Vec::new(); + let mut encoded_orig = Vec::new(); + + Encode::::encode(&example, &mut encoded); + Encode::::encode(&example.0, &mut encoded_orig); + + assert_eq!(encoded, encoded_orig); +} + +#[test] +#[cfg(feature = "mysql")] +fn encode_weak_enum_mysql() { + encode_weak_enum::(); +} + +#[test] +#[cfg(feature = "postgres")] +fn encode_weak_enum_postgres() { + encode_weak_enum::(); +} + +#[allow(dead_code)] +fn encode_weak_enum() +where + Weak: Encode, + i32: Encode, +{ + for example in [Weak::One, Weak::Two, Weak::Three].iter() { + let mut encoded = Vec::new(); + let mut encoded_orig = Vec::new(); + + Encode::::encode(example, &mut encoded); + Encode::::encode(&(*example as i32), &mut encoded_orig); + + assert_eq!(encoded, encoded_orig); + } +} + +#[test] +#[cfg(feature = "mysql")] +fn encode_strong_enum_mysql() { + encode_strong_enum::(); +} + +#[test] +#[cfg(feature = "postgres")] +fn encode_strong_enum_postgres() { + encode_strong_enum::(); +} + +#[allow(dead_code)] +fn encode_strong_enum() +where + Strong: Encode, + str: Encode, +{ + for (example, name) in [ + (Strong::One, "One"), + (Strong::Two, "Two"), + (Strong::Three, "four"), + ] + .iter() + { + let mut encoded = Vec::new(); + let mut encoded_orig = Vec::new(); + + Encode::::encode(example, &mut encoded); + Encode::::encode(name, &mut encoded_orig); + + assert_eq!(encoded, encoded_orig); + } +} + +#[test] +#[cfg(feature = "postgres")] +fn encode_struct_postgres() { + let field1 = "Foo".to_string(); + let field2 = 3; + let field3 = false; + + let example = Struct { + field1: field1.clone(), + field2, + field3, + }; let mut encoded = Vec::new(); Encode::::encode(&example, &mut encoded); - let decoded = Foo::decode(Some(sqlx::postgres::PgValue::Binary(&encoded))).unwrap(); - assert_eq!(example, decoded); + let string_oid = >::type_info().oid(); + let i64_oid = >::type_info().oid(); + let bool_oid = >::type_info().oid(); + + // 3 columns + assert_eq!(&[0, 0, 0, 3], &encoded[..4]); + let encoded = &encoded[4..]; + + // check field1 (string) + assert_eq!(&string_oid.to_be_bytes(), &encoded[0..4]); + assert_eq!(&(field1.len() as u32).to_be_bytes(), &encoded[4..8]); + assert_eq!(field1.as_bytes(), &encoded[8..8 + field1.len()]); + let encoded = &encoded[8 + field1.len()..]; + + // check field2 (i64) + assert_eq!(&i64_oid.to_be_bytes(), &encoded[0..4]); + assert_eq!(&8u32.to_be_bytes(), &encoded[4..8]); + assert_eq!(field2.to_be_bytes(), &encoded[8..16]); + let encoded = &encoded[16..]; + + // check field3 (bool) + assert_eq!(&bool_oid.to_be_bytes(), &encoded[0..4]); + assert_eq!(&1u32.to_be_bytes(), &encoded[4..8]); + assert_eq!(field3, encoded[8] != 0); + let encoded = &encoded[9..]; + + assert!(encoded.is_empty()); + + let string_size = >::size_hint(&field1); + let i64_size = >::size_hint(&field2); + let bool_size = >::size_hint(&field3); + + assert_eq!( + 4 + 3 * (4 + 4) + string_size + i64_size + bool_size, + example.size_hint() + ); } +#[test] #[cfg(feature = "mysql")] -fn decode_with_db() -where - Foo: for<'de> Decode<'de, sqlx::MySql> + Encode, -{ - let example = Foo(0x1122_3344); +fn decode_transparent_mysql() { + decode_with_db::(Transparent(0x1122_3344)); +} +#[test] +#[cfg(feature = "postgres")] +fn decode_transparent_postgres() { + decode_with_db::(Transparent(0x1122_3344)); +} + +#[test] +#[cfg(feature = "mysql")] +fn decode_weak_enum_mysql() { + decode_with_db::(Weak::One); + decode_with_db::(Weak::Two); + decode_with_db::(Weak::Three); +} + +#[test] +#[cfg(feature = "postgres")] +fn decode_weak_enum_postgres() { + decode_with_db::(Weak::One); + decode_with_db::(Weak::Two); + decode_with_db::(Weak::Three); +} + +#[test] +#[cfg(feature = "mysql")] +fn decode_strong_enum_mysql() { + decode_with_db::(Strong::One); + decode_with_db::(Strong::Two); + decode_with_db::(Strong::Three); +} + +#[test] +#[cfg(feature = "postgres")] +fn decode_strong_enum_postgres() { + decode_with_db::(Strong::One); + decode_with_db::(Strong::Two); + decode_with_db::(Strong::Three); +} + +#[test] +#[cfg(feature = "postgres")] +fn decode_struct_postgres() { + decode_with_db::(Struct { + field1: "Foo".to_string(), + field2: 3, + field3: true, + }); +} + +#[allow(dead_code)] +fn decode_with_db + Encode + PartialEq + Debug>(example: V) { let mut encoded = Vec::new(); - Encode::::encode(&example, &mut encoded); + Encode::::encode(&example, &mut encoded); - let decoded = Foo::decode(Some(sqlx::mysql::MySqlValue::Binary(&encoded))).unwrap(); + let decoded = V::decode(&encoded).unwrap(); assert_eq!(example, decoded); } + +#[test] +#[cfg(feature = "mysql")] +fn has_sql_type_transparent_mysql() { + has_sql_type_transparent::(); +} + +#[test] +#[cfg(feature = "postgres")] +fn has_sql_type_transparent_postgres() { + has_sql_type_transparent::(); +} + +#[allow(dead_code)] +fn has_sql_type_transparent() +where + DB: HasSqlType + HasSqlType, +{ + let info: DB::TypeInfo = >::type_info(); + let info_orig: DB::TypeInfo = >::type_info(); + assert!(info.compatible(&info_orig)); +} + +#[test] +#[cfg(feature = "mysql")] +fn has_sql_type_weak_enum_mysql() { + has_sql_type_weak_enum::(); +} + +#[test] +#[cfg(feature = "postgres")] +fn has_sql_type_weak_enum_postgres() { + has_sql_type_weak_enum::(); +} + +#[allow(dead_code)] +fn has_sql_type_weak_enum() +where + DB: HasSqlType + HasSqlType, +{ + let info: DB::TypeInfo = >::type_info(); + let info_orig: DB::TypeInfo = >::type_info(); + assert!(info.compatible(&info_orig)); +} + +#[test] +#[cfg(feature = "mysql")] +fn has_sql_type_strong_enum_mysql() { + let info: sqlx::mysql::MySqlTypeInfo = >::type_info(); + assert!(info.compatible(&sqlx::mysql::MySqlTypeInfo::r#enum())) +} + +#[test] +#[cfg(feature = "postgres")] +fn has_sql_type_strong_enum_postgres() { + let info: sqlx::postgres::PgTypeInfo = >::type_info(); + assert!(info.compatible(&sqlx::postgres::PgTypeInfo::with_oid(10101010))) +} + +#[test] +#[cfg(feature = "postgres")] +fn has_sql_type_struct_postgres() { + let info: sqlx::postgres::PgTypeInfo = >::type_info(); + assert!(info.compatible(&sqlx::postgres::PgTypeInfo::with_oid(20202020))) +}