diff --git a/sqlx-macros/src/derives/decode.rs b/sqlx-macros/src/derives/decode.rs index e3c5cc78..e37862ff 100644 --- a/sqlx-macros/src/derives/decode.rs +++ b/sqlx-macros/src/derives/decode.rs @@ -59,23 +59,37 @@ fn expand_derive_decode_transparent( let generics = &input.generics; let (_, ty_generics, _) = generics.split_for_impl(); - // add db type for impl generics & where clause - let mut generics = generics.clone(); - generics.params.insert(0, parse_quote!(DB: sqlx::Database)); - generics.params.insert(0, parse_quote!('r)); - generics - .make_where_clause() - .predicates - .push(parse_quote!(#ty: sqlx::decode::Decode<'r, DB>)); - let (impl_generics, _, where_clause) = generics.split_for_impl(); + let mut tts = proc_macro2::TokenStream::new(); - let tts = quote!( - impl #impl_generics sqlx::decode::Decode<'r, DB> for #ident #ty_generics #where_clause { - fn decode(value: >::ValueRef) -> std::result::Result> { - <#ty as sqlx::decode::Decode<'r, DB>>::decode(value).map(Self) + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl<'r> sqlx::decode::Decode<'r, sqlx::MySql> for #ident #ty_generics where #ty: sqlx::decode::Decode<'r, sqlx::MySql> { + fn decode(value: >::ValueRef) -> std::result::Result> { + <#ty as sqlx::decode::Decode<'r, sqlx::MySql>>::decode(value).map(Self) + } } - } - ); + )); + } + + if cfg!(feature = "postgres") { + tts.extend(quote!( + impl<'r> sqlx::decode::Decode<'r, sqlx::Postgres> for #ident #ty_generics where #ty: sqlx::decode::Decode<'r, sqlx::Postgres> { + fn decode(value: >::ValueRef) -> std::result::Result> { + <#ty as sqlx::decode::Decode<'r, sqlx::Postgres>>::decode(value).map(Self) + } + } + )); + } + + if cfg!(feature = "sqlite") { + tts.extend(quote!( + impl<'r> sqlx::decode::Decode<'r, sqlx::Sqlite> for #ident #ty_generics where #ty: sqlx::decode::Decode<'r, sqlx::Sqlite> { + fn decode(value: >::ValueRef) -> std::result::Result> { + <#ty as sqlx::decode::Decode<'r, sqlx::Sqlite>>::decode(value).map(Self) + } + } + )); + } Ok(tts) } @@ -98,19 +112,57 @@ fn expand_derive_decode_weak_enum( }) .collect::>(); - Ok(quote!( - impl<'r, DB: sqlx::Database> sqlx::decode::Decode<'r, DB> for #ident where #repr: sqlx::decode::Decode<'r, DB> { - fn decode(value: >::ValueRef) -> std::result::Result> { - let value = <#repr as sqlx::decode::Decode<'r, DB>>::decode(value)?; + let mut tts = proc_macro2::TokenStream::new(); - match value { - #(#arms)* + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl<'r> sqlx::decode::Decode<'r, sqlx::MySql> for #ident where #repr: sqlx::decode::Decode<'r, sqlx::MySql> { + fn decode(value: >::ValueRef) -> std::result::Result> { + let value = <#repr as sqlx::decode::Decode<'r, sqlx::MySql>>::decode(value)?; - _ => Err(Box::new(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into()))) + match value { + #(#arms)* + + _ => Err(Box::new(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into()))) + } } } - } - )) + )); + } + + if cfg!(feature = "postgres") { + tts.extend(quote!( + impl<'r> sqlx::decode::Decode<'r, sqlx::Postgres> for #ident where #repr: sqlx::decode::Decode<'r, sqlx::Postgres> { + fn decode(value: >::ValueRef) -> std::result::Result> { + let value = <#repr as sqlx::decode::Decode<'r, sqlx::Postgres>>::decode(value)?; + + match value { + #(#arms)* + + _ => Err(Box::new(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into()))) + } + } + } + )); + } + + if cfg!(feature = "sqlite") { + tts.extend(quote!( + impl<'r> sqlx::decode::Decode<'r, sqlx::Sqlite> for #ident where #repr: sqlx::decode::Decode<'r, sqlx::Sqlite> { + fn decode(value: >::ValueRef) -> std::result::Result> { + let value = <#repr as sqlx::decode::Decode<'r, sqlx::Sqlite>>::decode(value)?; + + match value { + #(#arms)* + + _ => Err(Box::new(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into()))) + } + } + } + )); + } + + Ok(tts) } fn expand_derive_decode_strong_enum( diff --git a/sqlx-macros/src/derives/encode.rs b/sqlx-macros/src/derives/encode.rs index 1560b07c..d8b64fea 100644 --- a/sqlx-macros/src/derives/encode.rs +++ b/sqlx-macros/src/derives/encode.rs @@ -68,28 +68,70 @@ fn expand_derive_encode_transparent( .params .insert(0, LifetimeDef::new(lifetime.clone()).into()); - generics.params.insert(0, parse_quote!(DB: sqlx::Database)); generics .make_where_clause() .predicates .push(parse_quote!(#ty: sqlx::encode::Encode<#lifetime, DB>)); - let (impl_generics, _, where_clause) = generics.split_for_impl(); - Ok(quote!( - impl #impl_generics sqlx::encode::Encode<#lifetime, DB> for #ident #ty_generics #where_clause { - fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { - <#ty as sqlx::encode::Encode<#lifetime, DB>>::encode_by_ref(&self.0, buf) - } + let (impl_generics, _, _) = generics.split_for_impl(); - fn produces(&self) -> Option { - <#ty as sqlx::encode::Encode<#lifetime, DB>>::produces(&self.0) - } + let mut tts = proc_macro2::TokenStream::new(); - fn size_hint(&self) -> usize { - <#ty as sqlx::encode::Encode<#lifetime, DB>>::size_hint(&self.0) + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl #impl_generics sqlx::encode::Encode<#lifetime, sqlx::MySql> for #ident #ty_generics where #ty: sqlx::encode::Encode<#lifetime, sqlx::MySql> { + fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { + <#ty as sqlx::encode::Encode<#lifetime, sqlx::MySql>>::encode_by_ref(&self.0, buf) + } + + fn produces(&self) -> Option { + <#ty as sqlx::encode::Encode<#lifetime, sqlx::MySql>>::produces(&self.0) + } + + fn size_hint(&self) -> usize { + <#ty as sqlx::encode::Encode<#lifetime, sqlx::MySql>>::size_hint(&self.0) + } } - } - )) + )); + } + + if cfg!(feature = "postgres") { + tts.extend(quote!( + impl #impl_generics sqlx::encode::Encode<#lifetime, sqlx::Postgres> for #ident #ty_generics where #ty: sqlx::encode::Encode<#lifetime, sqlx::Postgres> { + fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { + <#ty as sqlx::encode::Encode<#lifetime, sqlx::Postgres>>::encode_by_ref(&self.0, buf) + } + + fn produces(&self) -> Option { + <#ty as sqlx::encode::Encode<#lifetime, sqlx::Postgres>>::produces(&self.0) + } + + fn size_hint(&self) -> usize { + <#ty as sqlx::encode::Encode<#lifetime, sqlx::Postgres>>::size_hint(&self.0) + } + } + )); + } + + if cfg!(feature = "sqlite") { + tts.extend(quote!( + impl #impl_generics sqlx::encode::Encode<#lifetime, sqlx::Sqlite> for #ident #ty_generics where #ty: sqlx::encode::Encode<#lifetime, sqlx::Sqlite> { + fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { + <#ty as sqlx::encode::Encode<#lifetime, sqlx::Sqlite>>::encode_by_ref(&self.0, buf) + } + + fn produces(&self) -> Option { + <#ty as sqlx::encode::Encode<#lifetime, sqlx::Sqlite>>::produces(&self.0) + } + + fn size_hint(&self) -> usize { + <#ty as sqlx::encode::Encode<#lifetime, sqlx::Sqlite>>::size_hint(&self.0) + } + } + )); + } + + Ok(tts) } fn expand_derive_encode_weak_enum( @@ -101,21 +143,63 @@ fn expand_derive_encode_weak_enum( let ident = &input.ident; - Ok(quote!( - impl<'q, DB: sqlx::Database> sqlx::encode::Encode<'q, DB> for #ident where #repr: sqlx::encode::Encode<'q, DB> { - fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { - <#repr as sqlx::encode::Encode>::encode_by_ref(&(*self as #repr), buf) - } + let mut tts = proc_macro2::TokenStream::new(); - fn produces(&self) -> Option { - <#repr as sqlx::encode::Encode>::produces(&(*self as #repr)) - } + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl<'q> sqlx::encode::Encode<'q, sqlx::MySql> for #ident where #repr: sqlx::encode::Encode<'q, sqlx::MySql> { + fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { + <#repr as sqlx::encode::Encode>::encode_by_ref(&(*self as #repr), buf) + } - fn size_hint(&self) -> usize { - <#repr as sqlx::encode::Encode>::size_hint(&(*self as #repr)) + fn produces(&self) -> Option { + <#repr as sqlx::encode::Encode>::produces(&(*self as #repr)) + } + + fn size_hint(&self) -> usize { + <#repr as sqlx::encode::Encode>::size_hint(&(*self as #repr)) + } } - } - )) + )); + } + + if cfg!(feature = "postgres") { + tts.extend(quote!( + impl<'q> sqlx::encode::Encode<'q, sqlx::Postgres> for #ident where #repr: sqlx::encode::Encode<'q, sqlx::Postgres> { + fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { + <#repr as sqlx::encode::Encode>::encode_by_ref(&(*self as #repr), buf) + } + + fn produces(&self) -> Option { + <#repr as sqlx::encode::Encode>::produces(&(*self as #repr)) + } + + fn size_hint(&self) -> usize { + <#repr as sqlx::encode::Encode>::size_hint(&(*self as #repr)) + } + } + )); + } + + if cfg!(feature = "sqlite") { + tts.extend(quote!( + impl<'q> sqlx::encode::Encode<'q, sqlx::Sqlite> for #ident where #repr: sqlx::encode::Encode<'q, sqlx::Sqlite> { + fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { + <#repr as sqlx::encode::Encode>::encode_by_ref(&(*self as #repr), buf) + } + + fn produces(&self) -> Option { + <#repr as sqlx::encode::Encode>::produces(&(*self as #repr)) + } + + fn size_hint(&self) -> usize { + <#repr as sqlx::encode::Encode>::size_hint(&(*self as #repr)) + } + } + )); + } + + Ok(tts) } fn expand_derive_encode_strong_enum( @@ -143,25 +227,75 @@ fn expand_derive_encode_strong_enum( } } - Ok(quote!( - impl<'q, DB: sqlx::Database> sqlx::encode::Encode<'q, DB> for #ident where &'q str: sqlx::encode::Encode<'q, DB> { - fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { - let val = match self { - #(#value_arms)* - }; + let mut tts = proc_macro2::TokenStream::new(); - <&str as sqlx::encode::Encode<'q, DB>>::encode(val, buf) + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl<'q> sqlx::encode::Encode<'q, sqlx::MySql> for #ident where &'q str: sqlx::encode::Encode<'q, sqlx::MySql> { + fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { + let val = match self { + #(#value_arms)* + }; + + <&str as sqlx::encode::Encode<'q, sqlx::MySql>>::encode(val, buf) + } + + fn size_hint(&self) -> usize { + let val = match self { + #(#value_arms)* + }; + + <&str as sqlx::encode::Encode<'q, sqlx::MySql>>::size_hint(&val) + } } + )); + } - fn size_hint(&self) -> usize { - let val = match self { - #(#value_arms)* - }; + if cfg!(feature = "postgres") { + tts.extend(quote!( + impl<'q> sqlx::encode::Encode<'q, sqlx::Postgres> for #ident where &'q str: sqlx::encode::Encode<'q, sqlx::Postgres> { + fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { + let val = match self { + #(#value_arms)* + }; - <&str as sqlx::encode::Encode<'q, DB>>::size_hint(&val) + <&str as sqlx::encode::Encode<'q, sqlx::Postgres>>::encode(val, buf) + } + + fn size_hint(&self) -> usize { + let val = match self { + #(#value_arms)* + }; + + <&str as sqlx::encode::Encode<'q, sqlx::Postgres>>::size_hint(&val) + } } - } - )) + )); + } + + if cfg!(feature = "sqlite") { + tts.extend(quote!( + impl<'q> sqlx::encode::Encode<'q, sqlx::Sqlite> for #ident where &'q str: sqlx::encode::Encode<'q, sqlx::Sqlite> { + fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { + let val = match self { + #(#value_arms)* + }; + + <&str as sqlx::encode::Encode<'q, sqlx::Sqlite>>::encode(val, buf) + } + + fn size_hint(&self) -> usize { + let val = match self { + #(#value_arms)* + }; + + <&str as sqlx::encode::Encode<'q, sqlx::Sqlite>>::size_hint(&val) + } + } + )); + } + + Ok(tts) } fn expand_derive_encode_struct( diff --git a/sqlx-macros/src/derives/type.rs b/sqlx-macros/src/derives/type.rs index ad41dba5..b928c4ee 100644 --- a/sqlx-macros/src/derives/type.rs +++ b/sqlx-macros/src/derives/type.rs @@ -59,21 +59,64 @@ fn expand_derive_has_sql_type_transparent( if attr.transparent { let mut generics = generics.clone(); - generics.params.insert(0, parse_quote!(DB: sqlx::Database)); - generics - .make_where_clause() - .predicates - .push(parse_quote!(#ty: sqlx::Type)); - let (impl_generics, _, where_clause) = generics.split_for_impl(); + let mut tts = proc_macro2::TokenStream::new(); - return Ok(quote!( - impl #impl_generics sqlx::Type< DB > for #ident #ty_generics #where_clause { - fn type_info() -> DB::TypeInfo { - <#ty as sqlx::Type>::type_info() + if cfg!(feature = "mysql") { + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::Type)); + + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + tts.extend(quote!( + impl #impl_generics sqlx::Type for #ident #ty_generics #where_clause + { + fn type_info() -> sqlx::mysql::MySqlTypeInfo { + <#ty as sqlx::Type>::type_info() + } } - } - )); + )); + } + + if cfg!(feature = "postgres") { + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::Type)); + + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + tts.extend(quote!( + impl #impl_generics sqlx::Type for #ident #ty_generics #where_clause + { + fn type_info() -> sqlx::postgres::PgTypeInfo { + <#ty as sqlx::Type>::type_info() + } + } + )); + } + + if cfg!(feature = "sqlite") { + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::Type)); + + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + tts.extend(quote!( + impl #impl_generics sqlx::Type for #ident #ty_generics #where_clause + { + fn type_info() -> sqlx::sqlite::SqliteTypeInfo { + <#ty as sqlx::Type>::type_info() + } + } + )); + } + + return Ok(tts); } let mut tts = proc_macro2::TokenStream::new(); @@ -100,18 +143,49 @@ fn expand_derive_has_sql_type_weak_enum( let attr = check_weak_enum_attributes(input, variants)?; let repr = attr.repr.unwrap(); let ident = &input.ident; - let ts = quote!( - impl sqlx::Type for #ident - where - #repr: sqlx::Type, - { - fn type_info() -> DB::TypeInfo { - <#repr as sqlx::Type>::type_info() - } - } - ); - Ok(ts) + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl sqlx::Type for #ident + where + #repr: sqlx::Type, + { + fn type_info() -> sqlx::mysql::MySqlTypeInfo { + <#repr as sqlx::Type>::type_info() + } + } + )); + } + + if cfg!(feature = "postgres") { + tts.extend(quote!( + impl sqlx::Type for #ident + where + #repr: sqlx::Type, + { + fn type_info() -> sqlx::postgres::PgTypeInfo { + <#repr as sqlx::Type>::type_info() + } + } + )); + } + + if cfg!(feature = "sqlite") { + tts.extend(quote!( + impl sqlx::Type for #ident + where + #repr: sqlx::Type, + { + fn type_info() -> sqlx::sqlite::SqliteTypeInfo { + <#repr as sqlx::Type>::type_info() + } + } + )); + } + + Ok(tts) } fn expand_derive_has_sql_type_strong_enum(