diff --git a/sqlx-macros/src/derives.rs b/sqlx-macros/src/derives.rs deleted file mode 100644 index 8ddc6203..00000000 --- a/sqlx-macros/src/derives.rs +++ /dev/null @@ -1,846 +0,0 @@ -use proc_macro2::Ident; -use quote::quote; -use std::iter::FromIterator; -use syn::punctuated::Punctuated; -use syn::token::Comma; -use syn::{ - parse_quote, Arm, Attribute, Block, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, - Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Stmt, Variant, -}; - -macro_rules! assert_attribute { - ($e:expr, $err:expr, $input:expr) => { - if !$e { - return Err(syn::Error::new_spanned($input, $err)); - } - }; -} - -struct SqlxAttributes { - transparent: bool, - postgres_oid: Option, - repr: Option, - rename: Option, -} - -fn parse_attributes(input: &[Attribute]) -> syn::Result { - let mut transparent = None; - let mut postgres_oid = None; - let mut repr = None; - let mut rename = None; - - macro_rules! fail { - ($t:expr, $m:expr) => { - return Err(syn::Error::new_spanned($t, $m)); - }; - } - - macro_rules! try_set { - ($i:ident, $v:expr, $t:expr) => { - match $i { - None => $i = Some($v), - Some(_) => fail!($t, "duplicate attribute"), - } - }; - } - - for attr in input { - let meta = attr - .parse_meta() - .map_err(|e| syn::Error::new_spanned(attr, e))?; - match meta { - Meta::List(list) if list.path.is_ident("sqlx") => { - for value in list.nested.iter() { - match value { - NestedMeta::Meta(meta) => match meta { - Meta::Path(p) if p.is_ident("transparent") => { - try_set!(transparent, true, value) - } - Meta::NameValue(MetaNameValue { - path, - lit: Lit::Str(val), - .. - }) if path.is_ident("rename") => try_set!(rename, val.value(), value), - Meta::List(list) if list.path.is_ident("postgres") => { - for value in list.nested.iter() { - match value { - NestedMeta::Meta(Meta::NameValue(MetaNameValue { - path, - lit: Lit::Int(val), - .. - })) if path.is_ident("oid") => { - try_set!(postgres_oid, val.base10_parse()?, value); - } - u => fail!(u, "unexpected value"), - } - } - } - - u => fail!(u, "unexpected attribute"), - }, - u => fail!(u, "unexpected attribute"), - } - } - } - Meta::List(list) if list.path.is_ident("repr") => { - if list.nested.len() != 1 { - fail!(&list.nested, "expected one value") - } - match list.nested.first().unwrap() { - NestedMeta::Meta(Meta::Path(p)) if p.get_ident().is_some() => { - try_set!(repr, p.get_ident().unwrap().clone(), list); - } - u => fail!(u, "unexpected value"), - } - } - _ => {} - } - } - - Ok(SqlxAttributes { - transparent: transparent.unwrap_or(false), - postgres_oid, - repr, - rename, - }) -} - -fn check_transparent_attributes(input: &DeriveInput, field: &Field) -> syn::Result<()> { - let attributes = parse_attributes(&input.attrs)?; - assert_attribute!( - attributes.transparent, - "expected #[sqlx(transparent)]", - input - ); - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_none(), - "unexpected #[sqlx(postgres(oid = ..))]", - input - ); - assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", - field - ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); - let attributes = parse_attributes(&field.attrs)?; - assert_attribute!( - !attributes.transparent, - "unexpected #[sqlx(transparent)]", - field - ); - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_none(), - "unexpected #[sqlx(postgres(oid = ..))]", - field - ); - assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", - field - ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field); - Ok(()) -} - -fn check_enum_attributes<'a>( - input: &'a DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let attributes = parse_attributes(&input.attrs)?; - assert_attribute!( - !attributes.transparent, - "unexpected #[sqlx(transparent)]", - input - ); - assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", - input - ); - - for variant in variants { - let attributes = parse_attributes(&variant.attrs)?; - assert_attribute!( - !attributes.transparent, - "unexpected #[sqlx(transparent)]", - variant - ); - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_none(), - "unexpected #[sqlx(postgres(oid = ..))]", - variant - ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", variant); - } - - Ok(attributes) -} - -fn check_weak_enum_attributes( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let attributes = check_enum_attributes(input, variants)?; - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_none(), - "unexpected #[sqlx(postgres(oid = ..))]", - input - ); - assert_attribute!(attributes.repr.is_some(), "expected #[repr(..)]", input); - for variant in variants { - let attributes = parse_attributes(&variant.attrs)?; - assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", - variant - ); - } - Ok(attributes.repr.unwrap()) -} - -fn check_strong_enum_attributes( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let attributes = check_enum_attributes(input, variants)?; - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_some(), - "expected #[sqlx(postgres(oid = ..))]", - input - ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); - Ok(attributes) -} - -fn check_struct_attributes<'a>( - input: &'a DeriveInput, - fields: &Punctuated, -) -> syn::Result { - let attributes = parse_attributes(&input.attrs)?; - assert_attribute!( - !attributes.transparent, - "unexpected #[sqlx(transparent)]", - input - ); - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_some(), - "expected #[sqlx(postgres(oid = ..))]", - input - ); - assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", - input - ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); - - for field in fields { - let attributes = parse_attributes(&field.attrs)?; - assert_attribute!( - !attributes.transparent, - "unexpected #[sqlx(transparent)]", - field - ); - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_none(), - "unexpected #[sqlx(postgres(oid = ..))]", - field - ); - assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", - field - ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field); - } - - Ok(attributes) -} - -pub(crate) fn expand_derive_encode(input: &DeriveInput) -> syn::Result { - let args = parse_attributes(&input.attrs)?; - - match &input.data { - Data::Struct(DataStruct { - fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), - .. - }) if unnamed.len() == 1 => { - expand_derive_encode_transparent(&input, unnamed.first().unwrap()) - } - Data::Enum(DataEnum { variants, .. }) => match args.repr { - Some(_) => expand_derive_encode_weak_enum(input, variants), - None => expand_derive_encode_strong_enum(input, variants), - }, - Data::Struct(DataStruct { - fields: Fields::Named(FieldsNamed { named, .. }), - .. - }) => expand_derive_encode_struct(input, named), - _ => Err(syn::Error::new_spanned( - input, - "expected a tuple struct with a single field", - )), - } -} - -fn expand_derive_encode_transparent( - input: &DeriveInput, - field: &Field, -) -> syn::Result { - check_transparent_attributes(input, field)?; - - let ident = &input.ident; - let ty = &field.ty; - - // extract type generics - let generics = &input.generics; - let (_, ty_generics, _) = generics.split_for_impl(); - - // add db type for impl generics & where clause - let mut generics = generics.clone(); - generics.params.insert(0, parse_quote!(DB: sqlx::Database)); - generics - .make_where_clause() - .predicates - .push(parse_quote!(#ty: sqlx::encode::Encode)); - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - Ok(quote!( - impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { - fn encode(&self, buf: &mut std::vec::Vec) { - sqlx::encode::Encode::encode(&self.0, buf) - } - fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { - sqlx::encode::Encode::encode_nullable(&self.0, buf) - } - fn size_hint(&self) -> usize { - sqlx::encode::Encode::size_hint(&self.0) - } - } - )) -} - -fn expand_derive_encode_weak_enum( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let repr = check_weak_enum_attributes(input, &variants)?; - - let ident = &input.ident; - - Ok(quote!( - impl sqlx::encode::Encode for #ident where #repr: sqlx::encode::Encode { - fn encode(&self, buf: &mut std::vec::Vec) { - sqlx::encode::Encode::encode(&(*self as #repr), buf) - } - fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { - sqlx::encode::Encode::encode_nullable(&(*self as #repr), buf) - } - fn size_hint(&self) -> usize { - sqlx::encode::Encode::size_hint(&(*self as #repr)) - } - } - )) -} - -fn expand_derive_encode_strong_enum( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - check_strong_enum_attributes(input, &variants)?; - - let ident = &input.ident; - - let mut value_arms = Vec::new(); - for v in variants { - let id = &v.ident; - let attributes = parse_attributes(&v.attrs)?; - if let Some(rename) = attributes.rename { - value_arms.push(quote!(#ident :: #id => #rename,)); - } else { - let name = id.to_string(); - value_arms.push(quote!(#ident :: #id => #name,)); - } - } - - Ok(quote!( - impl sqlx::encode::Encode for #ident where str: sqlx::encode::Encode { - fn encode(&self, buf: &mut std::vec::Vec) { - let val = match self { - #(#value_arms)* - }; - >::encode(val, buf) - } - fn size_hint(&self) -> usize { - let val = match self { - #(#value_arms)* - }; - >::size_hint(val) - } - } - )) -} - -fn expand_derive_encode_struct( - input: &DeriveInput, - fields: &Punctuated, -) -> syn::Result { - check_struct_attributes(input, &fields)?; - - let mut tts = proc_macro2::TokenStream::new(); - - if cfg!(feature = "postgres") { - let ident = &input.ident; - - let column_count = fields.len(); - - // extract type generics - let generics = &input.generics; - let (_, ty_generics, _) = generics.split_for_impl(); - - // add db type for impl generics & where clause - let mut generics = generics.clone(); - let predicates = &mut generics.make_where_clause().predicates; - for field in fields { - let ty = &field.ty; - predicates.push(parse_quote!(#ty: sqlx::encode::Encode)); - predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); - } - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - let mut writes: Vec = Vec::new(); - for field in fields { - let id = &field.ident; - let ty = &field.ty; - writes.push(parse_quote!({ - // write oid - let info = >::type_info(); - buf.extend(&info.oid().to_be_bytes()); - - // write zeros for length - buf.extend(&[0; 4]); - - let start = buf.len(); - sqlx::encode::Encode::::encode(&self. #id, buf); - let end = buf.len(); - let size = end - start; - - // replaces zeros with actual length - buf[start-4..start].copy_from_slice(&(size as u32).to_be_bytes()); - })); - } - - let mut sizes: Vec = Vec::new(); - for field in fields { - let id = &field.ident; - let ty = &field.ty; - sizes.push( - parse_quote!(<#ty as sqlx::encode::Encode>::size_hint(&self. #id)), - ); - } - - tts.extend(quote!( - impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { - fn encode(&self, buf: &mut std::vec::Vec) { - buf.extend(&(#column_count as u32).to_be_bytes()); - #(#writes)* - } - fn size_hint(&self) -> usize { - 4 + #column_count * (4 + 4) + #(#sizes)+* - } - } - )); - } - - Ok(tts) -} - -pub(crate) fn expand_derive_decode(input: &DeriveInput) -> syn::Result { - let attrs = parse_attributes(&input.attrs)?; - match &input.data { - Data::Struct(DataStruct { - fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), - .. - }) if unnamed.len() == 1 => { - expand_derive_decode_transparent(input, unnamed.first().unwrap()) - } - Data::Enum(DataEnum { variants, .. }) => match attrs.repr { - Some(_) => expand_derive_decode_weak_enum(input, variants), - None => expand_derive_decode_strong_enum(input, variants), - }, - Data::Struct(DataStruct { - fields: Fields::Named(FieldsNamed { named, .. }), - .. - }) => expand_derive_decode_struct(input, named), - _ => Err(syn::Error::new_spanned( - input, - "expected a tuple struct with a single field", - )), - } -} - -fn expand_derive_decode_transparent( - input: &DeriveInput, - field: &Field, -) -> syn::Result { - check_transparent_attributes(input, field)?; - - let ident = &input.ident; - let ty = &field.ty; - - // extract type generics - let generics = &input.generics; - let (_, ty_generics, _) = generics.split_for_impl(); - - // add db type for impl generics & where clause - let mut generics = generics.clone(); - generics.params.insert(0, parse_quote!(DB: sqlx::Database)); - generics - .make_where_clause() - .predicates - .push(parse_quote!(#ty: sqlx::decode::Decode)); - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - Ok(quote!( - impl #impl_generics sqlx::decode::Decode for #ident #ty_generics #where_clause { - fn decode(raw: &[u8]) -> std::result::Result { - <#ty as sqlx::decode::Decode>::decode(raw).map(Self) - } - fn decode_null() -> std::result::Result { - <#ty as sqlx::decode::Decode>::decode_null().map(Self) - } - fn decode_nullable(raw: std::option::Option<&[u8]>) -> std::result::Result { - <#ty as sqlx::decode::Decode>::decode_nullable(raw).map(Self) - } - } - )) -} - -fn expand_derive_decode_weak_enum( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let repr = check_weak_enum_attributes(input, &variants)?; - - let ident = &input.ident; - let arms = variants - .iter() - .map(|v| { - let id = &v.ident; - parse_quote!(_ if (#ident :: #id as #repr) == val => Ok(#ident :: #id),) - }) - .collect::>(); - - Ok(quote!( - impl sqlx::decode::Decode for #ident where #repr: sqlx::decode::Decode { - fn decode(raw: &[u8]) -> std::result::Result { - let val = <#repr as sqlx::decode::Decode>::decode(raw)?; - match val { - #(#arms)* - _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) - } - } - } - )) -} - -fn expand_derive_decode_strong_enum( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - check_strong_enum_attributes(input, &variants)?; - - let ident = &input.ident; - - let mut value_arms = Vec::new(); - for v in variants { - let id = &v.ident; - let attributes = parse_attributes(&v.attrs)?; - if let Some(rename) = attributes.rename { - value_arms.push(quote!(#rename => Ok(#ident :: #id),)); - } else { - let name = id.to_string(); - value_arms.push(quote!(#name => Ok(#ident :: #id),)); - } - } - - // TODO: prevent heap allocation - Ok(quote!( - impl sqlx::decode::Decode for #ident where std::string::String: sqlx::decode::Decode { - fn decode(buf: &[u8]) -> std::result::Result { - let val = >::decode(buf)?; - match val.as_str() { - #(#value_arms)* - _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) - } - } - } - )) -} - -fn expand_derive_decode_struct( - input: &DeriveInput, - fields: &Punctuated, -) -> syn::Result { - check_struct_attributes(input, fields)?; - - let mut tts = proc_macro2::TokenStream::new(); - - if cfg!(feature = "postgres") { - let ident = &input.ident; - - let column_count = fields.len(); - - // extract type generics - let generics = &input.generics; - let (_, ty_generics, _) = generics.split_for_impl(); - - // add db type for impl generics & where clause - let mut generics = generics.clone(); - let predicates = &mut generics.make_where_clause().predicates; - for field in fields { - let ty = &field.ty; - predicates.push(parse_quote!(#ty: sqlx::decode::Decode)); - predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); - } - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - let mut reads: Vec> = Vec::new(); - let mut names: Vec = Vec::new(); - for field in fields { - let id = &field.ident; - names.push(id.clone().unwrap()); - let ty = &field.ty; - reads.push(parse_quote!( - if buf.len() < 8 { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); - } - - let oid = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[0..4]).unwrap()); - if oid != >::type_info().oid() { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid oid"))); - } - - let len = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[4..8]).unwrap()) as usize; - - if buf.len() < 8 + len { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); - } - - let raw = &buf[8..8+len]; - let #id = <#ty as sqlx::decode::Decode>::decode(raw)?; - - let buf = &buf[8+len..]; - )); - } - let reads = reads.into_iter().flatten(); - - tts.extend(quote!( - impl #impl_generics sqlx::decode::Decode for #ident#ty_generics #where_clause { - fn decode(buf: &[u8]) -> std::result::Result { - if buf.len() < 4 { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); - } - - let column_count = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[..4]).unwrap()) as usize; - if column_count != #column_count { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid column count"))); - } - let buf = &buf[4..]; - - #(#reads)* - - if !buf.is_empty() { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new(format!("Too much data sent ({} bytes left)", buf.len())))); - } - - Ok(#ident { - #(#names),* - }) - } - } - )) - } - Ok(tts) -} - -pub(crate) fn expand_derive_has_sql_type( - input: &DeriveInput, -) -> syn::Result { - let attrs = parse_attributes(&input.attrs)?; - match &input.data { - Data::Struct(DataStruct { - fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), - .. - }) if unnamed.len() == 1 => { - expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap()) - } - Data::Enum(DataEnum { variants, .. }) => match attrs.repr { - Some(_) => expand_derive_has_sql_type_weak_enum(input, variants), - None => expand_derive_has_sql_type_strong_enum(input, variants), - }, - Data::Struct(DataStruct { - fields: Fields::Named(FieldsNamed { named, .. }), - .. - }) => expand_derive_has_sql_type_struct(input, named), - _ => Err(syn::Error::new_spanned( - input, - "expected a tuple struct with a single field", - )), - } -} - -fn expand_derive_has_sql_type_transparent( - input: &DeriveInput, - field: &Field, -) -> syn::Result { - check_transparent_attributes(input, field)?; - - let ident = &input.ident; - let ty = &field.ty; - - // extract type generics - let generics = &input.generics; - let (impl_generics, ty_generics, _) = generics.split_for_impl(); - - // add db type for clause - let mut generics = generics.clone(); - generics - .make_where_clause() - .predicates - .push(parse_quote!(Self: sqlx::types::HasSqlType<#ty>)); - let (_, _, where_clause) = generics.split_for_impl(); - - let mut tts = proc_macro2::TokenStream::new(); - - if cfg!(feature = "mysql") { - tts.extend(quote!( - impl #impl_generics sqlx::types::HasSqlType< #ident #ty_generics > for sqlx::MySql #where_clause { - fn type_info() -> Self::TypeInfo { - >::type_info() - } - } - )); - } - - if cfg!(feature = "postgres") { - tts.extend(quote!( - impl #impl_generics sqlx::types::HasSqlType< #ident #ty_generics > for sqlx::Postgres #where_clause { - fn type_info() -> Self::TypeInfo { - >::type_info() - } - } - )); - } - - Ok(tts) -} - -fn expand_derive_has_sql_type_weak_enum( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let repr = check_weak_enum_attributes(input, variants)?; - - let ident = &input.ident; - let mut tts = proc_macro2::TokenStream::new(); - - if cfg!(feature = "mysql") { - tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::MySql where Self: sqlx::types::HasSqlType< #repr > { - fn type_info() -> Self::TypeInfo { - >::type_info() - } - } - )); - } - - if cfg!(feature = "postgres") { - tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres where Self: sqlx::types::HasSqlType< #repr > { - fn type_info() -> Self::TypeInfo { - >::type_info() - } - } - )); - } - - Ok(tts) -} - -fn expand_derive_has_sql_type_strong_enum( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let attributes = check_strong_enum_attributes(input, variants)?; - - let ident = &input.ident; - let mut tts = proc_macro2::TokenStream::new(); - - if cfg!(feature = "mysql") { - tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::MySql { - fn type_info() -> Self::TypeInfo { - sqlx::mysql::MySqlTypeInfo::r#enum() - } - } - )); - } - - if cfg!(feature = "postgres") { - let oid = attributes.postgres_oid.unwrap(); - tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres { - fn type_info() -> Self::TypeInfo { - sqlx::postgres::PgTypeInfo::with_oid(#oid) - } - } - )); - } - - Ok(tts) -} - -fn expand_derive_has_sql_type_struct( - input: &DeriveInput, - fields: &Punctuated, -) -> syn::Result { - let attributes = check_struct_attributes(input, fields)?; - - let ident = &input.ident; - let mut tts = proc_macro2::TokenStream::new(); - - if cfg!(feature = "postgres") { - let oid = attributes.postgres_oid.unwrap(); - tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres { - fn type_info() -> Self::TypeInfo { - sqlx::postgres::PgTypeInfo::with_oid(#oid) - } - } - )); - } - - Ok(tts) -} - -pub(crate) fn expand_derive_type(input: &DeriveInput) -> syn::Result { - let encode_tts = expand_derive_encode(input)?; - let decode_tts = expand_derive_decode(input)?; - let has_sql_type_tts = expand_derive_has_sql_type(input)?; - - let combined = proc_macro2::TokenStream::from_iter( - encode_tts - .into_iter() - .chain(decode_tts) - .chain(has_sql_type_tts), - ); - Ok(combined) -} diff --git a/sqlx-macros/src/derives/attributes.rs b/sqlx-macros/src/derives/attributes.rs new file mode 100644 index 00000000..72df6991 --- /dev/null +++ b/sqlx-macros/src/derives/attributes.rs @@ -0,0 +1,261 @@ +use proc_macro2::Ident; +use syn::punctuated::Punctuated; +use syn::token::Comma; +use syn::{Attribute, DeriveInput, Field, Lit, Meta, MetaNameValue, NestedMeta, Variant}; + +macro_rules! assert_attribute { + ($e:expr, $err:expr, $input:expr) => { + if !$e { + return Err(syn::Error::new_spanned($input, $err)); + } + }; +} + +pub struct SqlxAttributes { + pub transparent: bool, + pub postgres_oid: Option, + pub repr: Option, + pub rename: Option, +} + +pub fn parse_attributes(input: &[Attribute]) -> syn::Result { + let mut transparent = None; + let mut postgres_oid = None; + let mut repr = None; + let mut rename = None; + + macro_rules! fail { + ($t:expr, $m:expr) => { + return Err(syn::Error::new_spanned($t, $m)); + }; + } + + macro_rules! try_set { + ($i:ident, $v:expr, $t:expr) => { + match $i { + None => $i = Some($v), + Some(_) => fail!($t, "duplicate attribute"), + } + }; + } + + for attr in input { + let meta = attr + .parse_meta() + .map_err(|e| syn::Error::new_spanned(attr, e))?; + match meta { + Meta::List(list) if list.path.is_ident("sqlx") => { + for value in list.nested.iter() { + match value { + NestedMeta::Meta(meta) => match meta { + Meta::Path(p) if p.is_ident("transparent") => { + try_set!(transparent, true, value) + } + Meta::NameValue(MetaNameValue { + path, + lit: Lit::Str(val), + .. + }) if path.is_ident("rename") => try_set!(rename, val.value(), value), + Meta::List(list) if list.path.is_ident("postgres") => { + for value in list.nested.iter() { + match value { + NestedMeta::Meta(Meta::NameValue(MetaNameValue { + path, + lit: Lit::Int(val), + .. + })) if path.is_ident("oid") => { + try_set!(postgres_oid, val.base10_parse()?, value); + } + u => fail!(u, "unexpected value"), + } + } + } + + u => fail!(u, "unexpected attribute"), + }, + u => fail!(u, "unexpected attribute"), + } + } + } + Meta::List(list) if list.path.is_ident("repr") => { + if list.nested.len() != 1 { + fail!(&list.nested, "expected one value") + } + match list.nested.first().unwrap() { + NestedMeta::Meta(Meta::Path(p)) if p.get_ident().is_some() => { + try_set!(repr, p.get_ident().unwrap().clone(), list); + } + u => fail!(u, "unexpected value"), + } + } + _ => {} + } + } + + Ok(SqlxAttributes { + transparent: transparent.unwrap_or(false), + postgres_oid, + repr, + rename, + }) +} + +pub fn check_transparent_attributes(input: &DeriveInput, field: &Field) -> syn::Result<()> { + let attributes = parse_attributes(&input.attrs)?; + assert_attribute!( + attributes.transparent, + "expected #[sqlx(transparent)]", + input + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + field + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); + let attributes = parse_attributes(&field.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + field + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + field + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + field + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field); + Ok(()) +} + +pub fn check_enum_attributes<'a>( + input: &'a DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = parse_attributes(&input.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + input + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + input + ); + + for variant in variants { + let attributes = parse_attributes(&variant.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + variant + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + variant + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", variant); + } + + Ok(attributes) +} + +pub fn check_weak_enum_attributes( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = check_enum_attributes(input, variants)?; + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!(attributes.repr.is_some(), "expected #[repr(..)]", input); + for variant in variants { + let attributes = parse_attributes(&variant.attrs)?; + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + variant + ); + } + Ok(attributes.repr.unwrap()) +} + +pub fn check_strong_enum_attributes( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = check_enum_attributes(input, variants)?; + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_some(), + "expected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); + Ok(attributes) +} + +pub fn check_struct_attributes<'a>( + input: &'a DeriveInput, + fields: &Punctuated, +) -> syn::Result { + let attributes = parse_attributes(&input.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + input + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_some(), + "expected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + input + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); + + for field in fields { + let attributes = parse_attributes(&field.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + field + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + field + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + field + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field); + } + + Ok(attributes) +} diff --git a/sqlx-macros/src/derives/decode.rs b/sqlx-macros/src/derives/decode.rs new file mode 100644 index 00000000..2aa779f8 --- /dev/null +++ b/sqlx-macros/src/derives/decode.rs @@ -0,0 +1,221 @@ +use super::attributes::{ + check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, + check_weak_enum_attributes, parse_attributes, +}; +use proc_macro2::Ident; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::token::Comma; +use syn::{ + parse_quote, Arm, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed, + FieldsUnnamed, Stmt, Variant, +}; + +pub fn expand_derive_decode(input: &DeriveInput) -> syn::Result { + let attrs = parse_attributes(&input.attrs)?; + match &input.data { + Data::Struct(DataStruct { + fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), + .. + }) if unnamed.len() == 1 => { + expand_derive_decode_transparent(input, unnamed.first().unwrap()) + } + Data::Enum(DataEnum { variants, .. }) => match attrs.repr { + Some(_) => expand_derive_decode_weak_enum(input, variants), + None => expand_derive_decode_strong_enum(input, variants), + }, + Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { named, .. }), + .. + }) => expand_derive_decode_struct(input, named), + _ => Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )), + } +} + +fn expand_derive_decode_transparent( + input: &DeriveInput, + field: &Field, +) -> syn::Result { + check_transparent_attributes(input, field)?; + + let ident = &input.ident; + let ty = &field.ty; + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::decode::Decode)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + Ok(quote!( + impl #impl_generics sqlx::decode::Decode for #ident #ty_generics #where_clause { + fn decode(raw: &[u8]) -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode(raw).map(Self) + } + fn decode_null() -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode_null().map(Self) + } + fn decode_nullable(raw: std::option::Option<&[u8]>) -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode_nullable(raw).map(Self) + } + } + )) +} + +fn expand_derive_decode_weak_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let repr = check_weak_enum_attributes(input, &variants)?; + + let ident = &input.ident; + let arms = variants + .iter() + .map(|v| { + let id = &v.ident; + parse_quote!(_ if (#ident :: #id as #repr) == val => Ok(#ident :: #id),) + }) + .collect::>(); + + Ok(quote!( + impl sqlx::decode::Decode for #ident where #repr: sqlx::decode::Decode { + fn decode(raw: &[u8]) -> std::result::Result { + let val = <#repr as sqlx::decode::Decode>::decode(raw)?; + match val { + #(#arms)* + _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) + } + } + } + )) +} + +fn expand_derive_decode_strong_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + check_strong_enum_attributes(input, &variants)?; + + let ident = &input.ident; + + let mut value_arms = Vec::new(); + for v in variants { + let id = &v.ident; + let attributes = parse_attributes(&v.attrs)?; + if let Some(rename) = attributes.rename { + value_arms.push(quote!(#rename => Ok(#ident :: #id),)); + } else { + let name = id.to_string(); + value_arms.push(quote!(#name => Ok(#ident :: #id),)); + } + } + + // TODO: prevent heap allocation + Ok(quote!( + impl sqlx::decode::Decode for #ident where std::string::String: sqlx::decode::Decode { + fn decode(buf: &[u8]) -> std::result::Result { + let val = >::decode(buf)?; + match val.as_str() { + #(#value_arms)* + _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) + } + } + } + )) +} + +fn expand_derive_decode_struct( + input: &DeriveInput, + fields: &Punctuated, +) -> syn::Result { + check_struct_attributes(input, fields)?; + + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "postgres") { + let ident = &input.ident; + + let column_count = fields.len(); + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + let predicates = &mut generics.make_where_clause().predicates; + for field in fields { + let ty = &field.ty; + predicates.push(parse_quote!(#ty: sqlx::decode::Decode)); + predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); + } + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + let mut reads: Vec> = Vec::new(); + let mut names: Vec = Vec::new(); + for field in fields { + let id = &field.ident; + names.push(id.clone().unwrap()); + let ty = &field.ty; + reads.push(parse_quote!( + if buf.len() < 8 { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); + } + + let oid = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[0..4]).unwrap()); + if oid != >::type_info().oid() { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid oid"))); + } + + let len = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[4..8]).unwrap()) as usize; + + if buf.len() < 8 + len { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); + } + + let raw = &buf[8..8+len]; + let #id = <#ty as sqlx::decode::Decode>::decode(raw)?; + + let buf = &buf[8+len..]; + )); + } + let reads = reads.into_iter().flatten(); + + tts.extend(quote!( + impl #impl_generics sqlx::decode::Decode for #ident#ty_generics #where_clause { + fn decode(buf: &[u8]) -> std::result::Result { + if buf.len() < 4 { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); + } + + let column_count = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[..4]).unwrap()) as usize; + if column_count != #column_count { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid column count"))); + } + let buf = &buf[4..]; + + #(#reads)* + + if !buf.is_empty() { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new(format!("Too much data sent ({} bytes left)", buf.len())))); + } + + Ok(#ident { + #(#names),* + }) + } + } + )) + } + Ok(tts) +} diff --git a/sqlx-macros/src/derives/encode.rs b/sqlx-macros/src/derives/encode.rs new file mode 100644 index 00000000..03c91d00 --- /dev/null +++ b/sqlx-macros/src/derives/encode.rs @@ -0,0 +1,208 @@ +use super::attributes::{ + check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, + check_weak_enum_attributes, parse_attributes, +}; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::token::Comma; +use syn::{ + parse_quote, Block, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed, + FieldsUnnamed, Variant, +}; + +pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result { + let args = parse_attributes(&input.attrs)?; + + match &input.data { + Data::Struct(DataStruct { + fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), + .. + }) if unnamed.len() == 1 => { + expand_derive_encode_transparent(&input, unnamed.first().unwrap()) + } + Data::Enum(DataEnum { variants, .. }) => match args.repr { + Some(_) => expand_derive_encode_weak_enum(input, variants), + None => expand_derive_encode_strong_enum(input, variants), + }, + Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { named, .. }), + .. + }) => expand_derive_encode_struct(input, named), + _ => Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )), + } +} + +fn expand_derive_encode_transparent( + input: &DeriveInput, + field: &Field, +) -> syn::Result { + check_transparent_attributes(input, field)?; + + let ident = &input.ident; + let ty = &field.ty; + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::encode::Encode)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + Ok(quote!( + impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { + fn encode(&self, buf: &mut std::vec::Vec) { + sqlx::encode::Encode::encode(&self.0, buf) + } + fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { + sqlx::encode::Encode::encode_nullable(&self.0, buf) + } + fn size_hint(&self) -> usize { + sqlx::encode::Encode::size_hint(&self.0) + } + } + )) +} + +fn expand_derive_encode_weak_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let repr = check_weak_enum_attributes(input, &variants)?; + + let ident = &input.ident; + + Ok(quote!( + impl sqlx::encode::Encode for #ident where #repr: sqlx::encode::Encode { + fn encode(&self, buf: &mut std::vec::Vec) { + sqlx::encode::Encode::encode(&(*self as #repr), buf) + } + fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { + sqlx::encode::Encode::encode_nullable(&(*self as #repr), buf) + } + fn size_hint(&self) -> usize { + sqlx::encode::Encode::size_hint(&(*self as #repr)) + } + } + )) +} + +fn expand_derive_encode_strong_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + check_strong_enum_attributes(input, &variants)?; + + let ident = &input.ident; + + let mut value_arms = Vec::new(); + for v in variants { + let id = &v.ident; + let attributes = parse_attributes(&v.attrs)?; + if let Some(rename) = attributes.rename { + value_arms.push(quote!(#ident :: #id => #rename,)); + } else { + let name = id.to_string(); + value_arms.push(quote!(#ident :: #id => #name,)); + } + } + + Ok(quote!( + impl sqlx::encode::Encode for #ident where str: sqlx::encode::Encode { + fn encode(&self, buf: &mut std::vec::Vec) { + let val = match self { + #(#value_arms)* + }; + >::encode(val, buf) + } + fn size_hint(&self) -> usize { + let val = match self { + #(#value_arms)* + }; + >::size_hint(val) + } + } + )) +} + +fn expand_derive_encode_struct( + input: &DeriveInput, + fields: &Punctuated, +) -> syn::Result { + check_struct_attributes(input, &fields)?; + + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "postgres") { + let ident = &input.ident; + + let column_count = fields.len(); + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + let predicates = &mut generics.make_where_clause().predicates; + for field in fields { + let ty = &field.ty; + predicates.push(parse_quote!(#ty: sqlx::encode::Encode)); + predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); + } + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + let mut writes: Vec = Vec::new(); + for field in fields { + let id = &field.ident; + let ty = &field.ty; + writes.push(parse_quote!({ + // write oid + let info = >::type_info(); + buf.extend(&info.oid().to_be_bytes()); + + // write zeros for length + buf.extend(&[0; 4]); + + let start = buf.len(); + sqlx::encode::Encode::::encode(&self. #id, buf); + let end = buf.len(); + let size = end - start; + + // replaces zeros with actual length + buf[start-4..start].copy_from_slice(&(size as u32).to_be_bytes()); + })); + } + + let mut sizes: Vec = Vec::new(); + for field in fields { + let id = &field.ident; + let ty = &field.ty; + sizes.push( + parse_quote!(<#ty as sqlx::encode::Encode>::size_hint(&self. #id)), + ); + } + + tts.extend(quote!( + impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { + fn encode(&self, buf: &mut std::vec::Vec) { + buf.extend(&(#column_count as u32).to_be_bytes()); + #(#writes)* + } + fn size_hint(&self) -> usize { + 4 + #column_count * (4 + 4) + #(#sizes)+* + } + } + )); + } + + Ok(tts) +} diff --git a/sqlx-macros/src/derives/has_sql_type.rs b/sqlx-macros/src/derives/has_sql_type.rs new file mode 100644 index 00000000..64da9ae1 --- /dev/null +++ b/sqlx-macros/src/derives/has_sql_type.rs @@ -0,0 +1,169 @@ +use super::attributes::{ + check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, + check_weak_enum_attributes, parse_attributes, +}; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::token::Comma; +use syn::{ + parse_quote, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed, + FieldsUnnamed, Variant, +}; + +pub fn expand_derive_has_sql_type(input: &DeriveInput) -> syn::Result { + let attrs = parse_attributes(&input.attrs)?; + match &input.data { + Data::Struct(DataStruct { + fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), + .. + }) if unnamed.len() == 1 => { + expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap()) + } + Data::Enum(DataEnum { variants, .. }) => match attrs.repr { + Some(_) => expand_derive_has_sql_type_weak_enum(input, variants), + None => expand_derive_has_sql_type_strong_enum(input, variants), + }, + Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { named, .. }), + .. + }) => expand_derive_has_sql_type_struct(input, named), + _ => Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )), + } +} + +fn expand_derive_has_sql_type_transparent( + input: &DeriveInput, + field: &Field, +) -> syn::Result { + check_transparent_attributes(input, field)?; + + let ident = &input.ident; + let ty = &field.ty; + + // extract type generics + let generics = &input.generics; + let (impl_generics, ty_generics, _) = generics.split_for_impl(); + + // add db type for clause + let mut generics = generics.clone(); + generics + .make_where_clause() + .predicates + .push(parse_quote!(Self: sqlx::types::HasSqlType<#ty>)); + let (_, _, where_clause) = generics.split_for_impl(); + + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl #impl_generics sqlx::types::HasSqlType< #ident #ty_generics > for sqlx::MySql #where_clause { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + if cfg!(feature = "postgres") { + tts.extend(quote!( + impl #impl_generics sqlx::types::HasSqlType< #ident #ty_generics > for sqlx::Postgres #where_clause { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + Ok(tts) +} + +fn expand_derive_has_sql_type_weak_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let repr = check_weak_enum_attributes(input, variants)?; + + let ident = &input.ident; + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::MySql where Self: sqlx::types::HasSqlType< #repr > { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + if cfg!(feature = "postgres") { + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres where Self: sqlx::types::HasSqlType< #repr > { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + Ok(tts) +} + +fn expand_derive_has_sql_type_strong_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = check_strong_enum_attributes(input, variants)?; + + let ident = &input.ident; + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::MySql { + fn type_info() -> Self::TypeInfo { + sqlx::mysql::MySqlTypeInfo::r#enum() + } + } + )); + } + + if cfg!(feature = "postgres") { + let oid = attributes.postgres_oid.unwrap(); + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres { + fn type_info() -> Self::TypeInfo { + sqlx::postgres::PgTypeInfo::with_oid(#oid) + } + } + )); + } + + Ok(tts) +} + +fn expand_derive_has_sql_type_struct( + input: &DeriveInput, + fields: &Punctuated, +) -> syn::Result { + let attributes = check_struct_attributes(input, fields)?; + + let ident = &input.ident; + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "postgres") { + let oid = attributes.postgres_oid.unwrap(); + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres { + fn type_info() -> Self::TypeInfo { + sqlx::postgres::PgTypeInfo::with_oid(#oid) + } + } + )); + } + + Ok(tts) +} diff --git a/sqlx-macros/src/derives/mod.rs b/sqlx-macros/src/derives/mod.rs new file mode 100644 index 00000000..28d9eee7 --- /dev/null +++ b/sqlx-macros/src/derives/mod.rs @@ -0,0 +1,25 @@ +mod attributes; +mod decode; +mod encode; +mod has_sql_type; + +pub(crate) use decode::expand_derive_decode; +pub(crate) use encode::expand_derive_encode; +pub(crate) use has_sql_type::expand_derive_has_sql_type; + +use std::iter::FromIterator; +use syn::DeriveInput; + +pub(crate) fn expand_derive_type(input: &DeriveInput) -> syn::Result { + let encode_tts = expand_derive_encode(input)?; + let decode_tts = expand_derive_decode(input)?; + let has_sql_type_tts = expand_derive_has_sql_type(input)?; + + let combined = proc_macro2::TokenStream::from_iter( + encode_tts + .into_iter() + .chain(decode_tts) + .chain(has_sql_type_tts), + ); + Ok(combined) +}