diff --git a/sqlx-macros/src/derives/decode.rs b/sqlx-macros/src/derives/decode.rs index e37862ff2..e3c5cc783 100644 --- a/sqlx-macros/src/derives/decode.rs +++ b/sqlx-macros/src/derives/decode.rs @@ -59,37 +59,23 @@ fn expand_derive_decode_transparent( let generics = &input.generics; let (_, ty_generics, _) = generics.split_for_impl(); - let mut tts = proc_macro2::TokenStream::new(); + // add db type for impl generics & where clause + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics.params.insert(0, parse_quote!('r)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::decode::Decode<'r, DB>)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); - if cfg!(feature = "mysql") { - tts.extend(quote!( - impl<'r> sqlx::decode::Decode<'r, sqlx::MySql> for #ident #ty_generics where #ty: sqlx::decode::Decode<'r, sqlx::MySql> { - fn decode(value: >::ValueRef) -> std::result::Result> { - <#ty as sqlx::decode::Decode<'r, sqlx::MySql>>::decode(value).map(Self) - } + let tts = quote!( + impl #impl_generics sqlx::decode::Decode<'r, DB> for #ident #ty_generics #where_clause { + fn decode(value: >::ValueRef) -> std::result::Result> { + <#ty as sqlx::decode::Decode<'r, DB>>::decode(value).map(Self) } - )); - } - - if cfg!(feature = "postgres") { - tts.extend(quote!( - impl<'r> sqlx::decode::Decode<'r, sqlx::Postgres> for #ident #ty_generics where #ty: sqlx::decode::Decode<'r, sqlx::Postgres> { - fn decode(value: >::ValueRef) -> std::result::Result> { - <#ty as sqlx::decode::Decode<'r, sqlx::Postgres>>::decode(value).map(Self) - } - } - )); - } - - if cfg!(feature = "sqlite") { - tts.extend(quote!( - impl<'r> sqlx::decode::Decode<'r, sqlx::Sqlite> for #ident #ty_generics where #ty: sqlx::decode::Decode<'r, sqlx::Sqlite> { - fn decode(value: >::ValueRef) -> std::result::Result> { - <#ty as sqlx::decode::Decode<'r, sqlx::Sqlite>>::decode(value).map(Self) - } - } - )); - } + } + ); Ok(tts) } @@ -112,57 +98,19 @@ fn expand_derive_decode_weak_enum( }) .collect::>(); - let mut tts = proc_macro2::TokenStream::new(); + Ok(quote!( + impl<'r, DB: sqlx::Database> sqlx::decode::Decode<'r, DB> for #ident where #repr: sqlx::decode::Decode<'r, DB> { + fn decode(value: >::ValueRef) -> std::result::Result> { + let value = <#repr as sqlx::decode::Decode<'r, DB>>::decode(value)?; - if cfg!(feature = "mysql") { - tts.extend(quote!( - impl<'r> sqlx::decode::Decode<'r, sqlx::MySql> for #ident where #repr: sqlx::decode::Decode<'r, sqlx::MySql> { - fn decode(value: >::ValueRef) -> std::result::Result> { - let value = <#repr as sqlx::decode::Decode<'r, sqlx::MySql>>::decode(value)?; + match value { + #(#arms)* - match value { - #(#arms)* - - _ => Err(Box::new(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into()))) - } + _ => Err(Box::new(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into()))) } } - )); - } - - if cfg!(feature = "postgres") { - tts.extend(quote!( - impl<'r> sqlx::decode::Decode<'r, sqlx::Postgres> for #ident where #repr: sqlx::decode::Decode<'r, sqlx::Postgres> { - fn decode(value: >::ValueRef) -> std::result::Result> { - let value = <#repr as sqlx::decode::Decode<'r, sqlx::Postgres>>::decode(value)?; - - match value { - #(#arms)* - - _ => Err(Box::new(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into()))) - } - } - } - )); - } - - if cfg!(feature = "sqlite") { - tts.extend(quote!( - impl<'r> sqlx::decode::Decode<'r, sqlx::Sqlite> for #ident where #repr: sqlx::decode::Decode<'r, sqlx::Sqlite> { - fn decode(value: >::ValueRef) -> std::result::Result> { - let value = <#repr as sqlx::decode::Decode<'r, sqlx::Sqlite>>::decode(value)?; - - match value { - #(#arms)* - - _ => Err(Box::new(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into()))) - } - } - } - )); - } - - Ok(tts) + } + )) } fn expand_derive_decode_strong_enum( diff --git a/sqlx-macros/src/derives/encode.rs b/sqlx-macros/src/derives/encode.rs index d8b64fea1..1560b07c2 100644 --- a/sqlx-macros/src/derives/encode.rs +++ b/sqlx-macros/src/derives/encode.rs @@ -68,70 +68,28 @@ fn expand_derive_encode_transparent( .params .insert(0, LifetimeDef::new(lifetime.clone()).into()); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); generics .make_where_clause() .predicates .push(parse_quote!(#ty: sqlx::encode::Encode<#lifetime, DB>)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); - let (impl_generics, _, _) = generics.split_for_impl(); - - let mut tts = proc_macro2::TokenStream::new(); - - if cfg!(feature = "mysql") { - tts.extend(quote!( - impl #impl_generics sqlx::encode::Encode<#lifetime, sqlx::MySql> for #ident #ty_generics where #ty: sqlx::encode::Encode<#lifetime, sqlx::MySql> { - fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { - <#ty as sqlx::encode::Encode<#lifetime, sqlx::MySql>>::encode_by_ref(&self.0, buf) - } - - fn produces(&self) -> Option { - <#ty as sqlx::encode::Encode<#lifetime, sqlx::MySql>>::produces(&self.0) - } - - fn size_hint(&self) -> usize { - <#ty as sqlx::encode::Encode<#lifetime, sqlx::MySql>>::size_hint(&self.0) - } + Ok(quote!( + impl #impl_generics sqlx::encode::Encode<#lifetime, DB> for #ident #ty_generics #where_clause { + fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { + <#ty as sqlx::encode::Encode<#lifetime, DB>>::encode_by_ref(&self.0, buf) } - )); - } - if cfg!(feature = "postgres") { - tts.extend(quote!( - impl #impl_generics sqlx::encode::Encode<#lifetime, sqlx::Postgres> for #ident #ty_generics where #ty: sqlx::encode::Encode<#lifetime, sqlx::Postgres> { - fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { - <#ty as sqlx::encode::Encode<#lifetime, sqlx::Postgres>>::encode_by_ref(&self.0, buf) - } - - fn produces(&self) -> Option { - <#ty as sqlx::encode::Encode<#lifetime, sqlx::Postgres>>::produces(&self.0) - } - - fn size_hint(&self) -> usize { - <#ty as sqlx::encode::Encode<#lifetime, sqlx::Postgres>>::size_hint(&self.0) - } + fn produces(&self) -> Option { + <#ty as sqlx::encode::Encode<#lifetime, DB>>::produces(&self.0) } - )); - } - if cfg!(feature = "sqlite") { - tts.extend(quote!( - impl #impl_generics sqlx::encode::Encode<#lifetime, sqlx::Sqlite> for #ident #ty_generics where #ty: sqlx::encode::Encode<#lifetime, sqlx::Sqlite> { - fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { - <#ty as sqlx::encode::Encode<#lifetime, sqlx::Sqlite>>::encode_by_ref(&self.0, buf) - } - - fn produces(&self) -> Option { - <#ty as sqlx::encode::Encode<#lifetime, sqlx::Sqlite>>::produces(&self.0) - } - - fn size_hint(&self) -> usize { - <#ty as sqlx::encode::Encode<#lifetime, sqlx::Sqlite>>::size_hint(&self.0) - } + fn size_hint(&self) -> usize { + <#ty as sqlx::encode::Encode<#lifetime, DB>>::size_hint(&self.0) } - )); - } - - Ok(tts) + } + )) } fn expand_derive_encode_weak_enum( @@ -143,63 +101,21 @@ fn expand_derive_encode_weak_enum( let ident = &input.ident; - let mut tts = proc_macro2::TokenStream::new(); - - if cfg!(feature = "mysql") { - tts.extend(quote!( - impl<'q> sqlx::encode::Encode<'q, sqlx::MySql> for #ident where #repr: sqlx::encode::Encode<'q, sqlx::MySql> { - fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { - <#repr as sqlx::encode::Encode>::encode_by_ref(&(*self as #repr), buf) - } - - fn produces(&self) -> Option { - <#repr as sqlx::encode::Encode>::produces(&(*self as #repr)) - } - - fn size_hint(&self) -> usize { - <#repr as sqlx::encode::Encode>::size_hint(&(*self as #repr)) - } + Ok(quote!( + impl<'q, DB: sqlx::Database> sqlx::encode::Encode<'q, DB> for #ident where #repr: sqlx::encode::Encode<'q, DB> { + fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { + <#repr as sqlx::encode::Encode>::encode_by_ref(&(*self as #repr), buf) } - )); - } - if cfg!(feature = "postgres") { - tts.extend(quote!( - impl<'q> sqlx::encode::Encode<'q, sqlx::Postgres> for #ident where #repr: sqlx::encode::Encode<'q, sqlx::Postgres> { - fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { - <#repr as sqlx::encode::Encode>::encode_by_ref(&(*self as #repr), buf) - } - - fn produces(&self) -> Option { - <#repr as sqlx::encode::Encode>::produces(&(*self as #repr)) - } - - fn size_hint(&self) -> usize { - <#repr as sqlx::encode::Encode>::size_hint(&(*self as #repr)) - } + fn produces(&self) -> Option { + <#repr as sqlx::encode::Encode>::produces(&(*self as #repr)) } - )); - } - if cfg!(feature = "sqlite") { - tts.extend(quote!( - impl<'q> sqlx::encode::Encode<'q, sqlx::Sqlite> for #ident where #repr: sqlx::encode::Encode<'q, sqlx::Sqlite> { - fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { - <#repr as sqlx::encode::Encode>::encode_by_ref(&(*self as #repr), buf) - } - - fn produces(&self) -> Option { - <#repr as sqlx::encode::Encode>::produces(&(*self as #repr)) - } - - fn size_hint(&self) -> usize { - <#repr as sqlx::encode::Encode>::size_hint(&(*self as #repr)) - } + fn size_hint(&self) -> usize { + <#repr as sqlx::encode::Encode>::size_hint(&(*self as #repr)) } - )); - } - - Ok(tts) + } + )) } fn expand_derive_encode_strong_enum( @@ -227,75 +143,25 @@ fn expand_derive_encode_strong_enum( } } - let mut tts = proc_macro2::TokenStream::new(); + Ok(quote!( + impl<'q, DB: sqlx::Database> sqlx::encode::Encode<'q, DB> for #ident where &'q str: sqlx::encode::Encode<'q, DB> { + fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { + let val = match self { + #(#value_arms)* + }; - if cfg!(feature = "mysql") { - tts.extend(quote!( - impl<'q> sqlx::encode::Encode<'q, sqlx::MySql> for #ident where &'q str: sqlx::encode::Encode<'q, sqlx::MySql> { - fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { - let val = match self { - #(#value_arms)* - }; - - <&str as sqlx::encode::Encode<'q, sqlx::MySql>>::encode(val, buf) - } - - fn size_hint(&self) -> usize { - let val = match self { - #(#value_arms)* - }; - - <&str as sqlx::encode::Encode<'q, sqlx::MySql>>::size_hint(&val) - } + <&str as sqlx::encode::Encode<'q, DB>>::encode(val, buf) } - )); - } - if cfg!(feature = "postgres") { - tts.extend(quote!( - impl<'q> sqlx::encode::Encode<'q, sqlx::Postgres> for #ident where &'q str: sqlx::encode::Encode<'q, sqlx::Postgres> { - fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { - let val = match self { - #(#value_arms)* - }; + fn size_hint(&self) -> usize { + let val = match self { + #(#value_arms)* + }; - <&str as sqlx::encode::Encode<'q, sqlx::Postgres>>::encode(val, buf) - } - - fn size_hint(&self) -> usize { - let val = match self { - #(#value_arms)* - }; - - <&str as sqlx::encode::Encode<'q, sqlx::Postgres>>::size_hint(&val) - } + <&str as sqlx::encode::Encode<'q, DB>>::size_hint(&val) } - )); - } - - if cfg!(feature = "sqlite") { - tts.extend(quote!( - impl<'q> sqlx::encode::Encode<'q, sqlx::Sqlite> for #ident where &'q str: sqlx::encode::Encode<'q, sqlx::Sqlite> { - fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { - let val = match self { - #(#value_arms)* - }; - - <&str as sqlx::encode::Encode<'q, sqlx::Sqlite>>::encode(val, buf) - } - - fn size_hint(&self) -> usize { - let val = match self { - #(#value_arms)* - }; - - <&str as sqlx::encode::Encode<'q, sqlx::Sqlite>>::size_hint(&val) - } - } - )); - } - - Ok(tts) + } + )) } fn expand_derive_encode_struct( diff --git a/sqlx-macros/src/derives/type.rs b/sqlx-macros/src/derives/type.rs index b928c4eec..ad41dba58 100644 --- a/sqlx-macros/src/derives/type.rs +++ b/sqlx-macros/src/derives/type.rs @@ -59,64 +59,21 @@ fn expand_derive_has_sql_type_transparent( if attr.transparent { let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::Type)); - let mut tts = proc_macro2::TokenStream::new(); + let (impl_generics, _, where_clause) = generics.split_for_impl(); - if cfg!(feature = "mysql") { - generics - .make_where_clause() - .predicates - .push(parse_quote!(#ty: sqlx::Type)); - - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - tts.extend(quote!( - impl #impl_generics sqlx::Type for #ident #ty_generics #where_clause - { - fn type_info() -> sqlx::mysql::MySqlTypeInfo { - <#ty as sqlx::Type>::type_info() - } + return Ok(quote!( + impl #impl_generics sqlx::Type< DB > for #ident #ty_generics #where_clause { + fn type_info() -> DB::TypeInfo { + <#ty as sqlx::Type>::type_info() } - )); - } - - if cfg!(feature = "postgres") { - generics - .make_where_clause() - .predicates - .push(parse_quote!(#ty: sqlx::Type)); - - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - tts.extend(quote!( - impl #impl_generics sqlx::Type for #ident #ty_generics #where_clause - { - fn type_info() -> sqlx::postgres::PgTypeInfo { - <#ty as sqlx::Type>::type_info() - } - } - )); - } - - if cfg!(feature = "sqlite") { - generics - .make_where_clause() - .predicates - .push(parse_quote!(#ty: sqlx::Type)); - - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - tts.extend(quote!( - impl #impl_generics sqlx::Type for #ident #ty_generics #where_clause - { - fn type_info() -> sqlx::sqlite::SqliteTypeInfo { - <#ty as sqlx::Type>::type_info() - } - } - )); - } - - return Ok(tts); + } + )); } let mut tts = proc_macro2::TokenStream::new(); @@ -143,49 +100,18 @@ fn expand_derive_has_sql_type_weak_enum( let attr = check_weak_enum_attributes(input, variants)?; let repr = attr.repr.unwrap(); let ident = &input.ident; - - let mut tts = proc_macro2::TokenStream::new(); - - if cfg!(feature = "mysql") { - tts.extend(quote!( - impl sqlx::Type for #ident - where - #repr: sqlx::Type, - { - fn type_info() -> sqlx::mysql::MySqlTypeInfo { - <#repr as sqlx::Type>::type_info() - } + let ts = quote!( + impl sqlx::Type for #ident + where + #repr: sqlx::Type, + { + fn type_info() -> DB::TypeInfo { + <#repr as sqlx::Type>::type_info() } - )); - } + } + ); - if cfg!(feature = "postgres") { - tts.extend(quote!( - impl sqlx::Type for #ident - where - #repr: sqlx::Type, - { - fn type_info() -> sqlx::postgres::PgTypeInfo { - <#repr as sqlx::Type>::type_info() - } - } - )); - } - - if cfg!(feature = "sqlite") { - tts.extend(quote!( - impl sqlx::Type for #ident - where - #repr: sqlx::Type, - { - fn type_info() -> sqlx::sqlite::SqliteTypeInfo { - <#repr as sqlx::Type>::type_info() - } - } - )); - } - - Ok(tts) + Ok(ts) } fn expand_derive_has_sql_type_strong_enum(