refactor(derives): use separate impls per database

database-generic impls are *mostly* impossible in SQLx so we recently
capitalized on that and made it *totally* impossible (until Rust
has specialization and lazy norm)
This commit is contained in:
Ryan Leckey 2020-06-27 05:30:38 -07:00
parent af7bd71ab2
commit e3483230e0
3 changed files with 347 additions and 87 deletions

View File

@ -59,23 +59,37 @@ fn expand_derive_decode_transparent(
let generics = &input.generics; let generics = &input.generics;
let (_, ty_generics, _) = generics.split_for_impl(); let (_, ty_generics, _) = generics.split_for_impl();
// add db type for impl generics & where clause let mut tts = proc_macro2::TokenStream::new();
let mut generics = generics.clone();
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
generics.params.insert(0, parse_quote!('r));
generics
.make_where_clause()
.predicates
.push(parse_quote!(#ty: sqlx::decode::Decode<'r, DB>));
let (impl_generics, _, where_clause) = generics.split_for_impl();
let tts = quote!( if cfg!(feature = "mysql") {
impl #impl_generics sqlx::decode::Decode<'r, DB> for #ident #ty_generics #where_clause { tts.extend(quote!(
fn decode(value: <DB as sqlx::database::HasValueRef<'r>>::ValueRef) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> { impl<'r> sqlx::decode::Decode<'r, sqlx::MySql> for #ident #ty_generics where #ty: sqlx::decode::Decode<'r, sqlx::MySql> {
<#ty as sqlx::decode::Decode<'r, DB>>::decode(value).map(Self) fn decode(value: <sqlx::MySql as sqlx::database::HasValueRef<'r>>::ValueRef) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
<#ty as sqlx::decode::Decode<'r, sqlx::MySql>>::decode(value).map(Self)
} }
} }
); ));
}
if cfg!(feature = "postgres") {
tts.extend(quote!(
impl<'r> sqlx::decode::Decode<'r, sqlx::Postgres> for #ident #ty_generics where #ty: sqlx::decode::Decode<'r, sqlx::Postgres> {
fn decode(value: <sqlx::Postgres as sqlx::database::HasValueRef<'r>>::ValueRef) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
<#ty as sqlx::decode::Decode<'r, sqlx::Postgres>>::decode(value).map(Self)
}
}
));
}
if cfg!(feature = "sqlite") {
tts.extend(quote!(
impl<'r> sqlx::decode::Decode<'r, sqlx::Sqlite> for #ident #ty_generics where #ty: sqlx::decode::Decode<'r, sqlx::Sqlite> {
fn decode(value: <sqlx::Sqlite as sqlx::database::HasValueRef<'r>>::ValueRef) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
<#ty as sqlx::decode::Decode<'r, sqlx::Sqlite>>::decode(value).map(Self)
}
}
));
}
Ok(tts) Ok(tts)
} }
@ -98,10 +112,13 @@ fn expand_derive_decode_weak_enum(
}) })
.collect::<Vec<Arm>>(); .collect::<Vec<Arm>>();
Ok(quote!( let mut tts = proc_macro2::TokenStream::new();
impl<'r, DB: sqlx::Database> sqlx::decode::Decode<'r, DB> for #ident where #repr: sqlx::decode::Decode<'r, DB> {
fn decode(value: <DB as sqlx::database::HasValueRef<'r>>::ValueRef) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> { if cfg!(feature = "mysql") {
let value = <#repr as sqlx::decode::Decode<'r, DB>>::decode(value)?; tts.extend(quote!(
impl<'r> sqlx::decode::Decode<'r, sqlx::MySql> for #ident where #repr: sqlx::decode::Decode<'r, sqlx::MySql> {
fn decode(value: <sqlx::MySql as sqlx::database::HasValueRef<'r>>::ValueRef) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
let value = <#repr as sqlx::decode::Decode<'r, sqlx::MySql>>::decode(value)?;
match value { match value {
#(#arms)* #(#arms)*
@ -110,7 +127,42 @@ fn expand_derive_decode_weak_enum(
} }
} }
} }
)) ));
}
if cfg!(feature = "postgres") {
tts.extend(quote!(
impl<'r> sqlx::decode::Decode<'r, sqlx::Postgres> for #ident where #repr: sqlx::decode::Decode<'r, sqlx::Postgres> {
fn decode(value: <sqlx::Postgres as sqlx::database::HasValueRef<'r>>::ValueRef) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
let value = <#repr as sqlx::decode::Decode<'r, sqlx::Postgres>>::decode(value)?;
match value {
#(#arms)*
_ => Err(Box::new(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into())))
}
}
}
));
}
if cfg!(feature = "sqlite") {
tts.extend(quote!(
impl<'r> sqlx::decode::Decode<'r, sqlx::Sqlite> for #ident where #repr: sqlx::decode::Decode<'r, sqlx::Sqlite> {
fn decode(value: <sqlx::Sqlite as sqlx::database::HasValueRef<'r>>::ValueRef) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
let value = <#repr as sqlx::decode::Decode<'r, sqlx::Sqlite>>::decode(value)?;
match value {
#(#arms)*
_ => Err(Box::new(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into())))
}
}
}
));
}
Ok(tts)
} }
fn expand_derive_decode_strong_enum( fn expand_derive_decode_strong_enum(

View File

@ -68,28 +68,70 @@ fn expand_derive_encode_transparent(
.params .params
.insert(0, LifetimeDef::new(lifetime.clone()).into()); .insert(0, LifetimeDef::new(lifetime.clone()).into());
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
generics generics
.make_where_clause() .make_where_clause()
.predicates .predicates
.push(parse_quote!(#ty: sqlx::encode::Encode<#lifetime, DB>)); .push(parse_quote!(#ty: sqlx::encode::Encode<#lifetime, DB>));
let (impl_generics, _, where_clause) = generics.split_for_impl();
Ok(quote!( let (impl_generics, _, _) = generics.split_for_impl();
impl #impl_generics sqlx::encode::Encode<#lifetime, DB> for #ident #ty_generics #where_clause {
fn encode_by_ref(&self, buf: &mut <DB as sqlx::database::HasArguments<#lifetime>>::ArgumentBuffer) -> sqlx::encode::IsNull { let mut tts = proc_macro2::TokenStream::new();
<#ty as sqlx::encode::Encode<#lifetime, DB>>::encode_by_ref(&self.0, buf)
if cfg!(feature = "mysql") {
tts.extend(quote!(
impl #impl_generics sqlx::encode::Encode<#lifetime, sqlx::MySql> for #ident #ty_generics where #ty: sqlx::encode::Encode<#lifetime, sqlx::MySql> {
fn encode_by_ref(&self, buf: &mut <sqlx::MySql as sqlx::database::HasArguments<#lifetime>>::ArgumentBuffer) -> sqlx::encode::IsNull {
<#ty as sqlx::encode::Encode<#lifetime, sqlx::MySql>>::encode_by_ref(&self.0, buf)
} }
fn produces(&self) -> Option<DB::TypeInfo> { fn produces(&self) -> Option<sqlx::mysql::MySqlTypeInfo> {
<#ty as sqlx::encode::Encode<#lifetime, DB>>::produces(&self.0) <#ty as sqlx::encode::Encode<#lifetime, sqlx::MySql>>::produces(&self.0)
} }
fn size_hint(&self) -> usize { fn size_hint(&self) -> usize {
<#ty as sqlx::encode::Encode<#lifetime, DB>>::size_hint(&self.0) <#ty as sqlx::encode::Encode<#lifetime, sqlx::MySql>>::size_hint(&self.0)
} }
} }
)) ));
}
if cfg!(feature = "postgres") {
tts.extend(quote!(
impl #impl_generics sqlx::encode::Encode<#lifetime, sqlx::Postgres> for #ident #ty_generics where #ty: sqlx::encode::Encode<#lifetime, sqlx::Postgres> {
fn encode_by_ref(&self, buf: &mut <sqlx::Postgres as sqlx::database::HasArguments<#lifetime>>::ArgumentBuffer) -> sqlx::encode::IsNull {
<#ty as sqlx::encode::Encode<#lifetime, sqlx::Postgres>>::encode_by_ref(&self.0, buf)
}
fn produces(&self) -> Option<sqlx::postgres::PgTypeInfo> {
<#ty as sqlx::encode::Encode<#lifetime, sqlx::Postgres>>::produces(&self.0)
}
fn size_hint(&self) -> usize {
<#ty as sqlx::encode::Encode<#lifetime, sqlx::Postgres>>::size_hint(&self.0)
}
}
));
}
if cfg!(feature = "sqlite") {
tts.extend(quote!(
impl #impl_generics sqlx::encode::Encode<#lifetime, sqlx::Sqlite> for #ident #ty_generics where #ty: sqlx::encode::Encode<#lifetime, sqlx::Sqlite> {
fn encode_by_ref(&self, buf: &mut <sqlx::Sqlite as sqlx::database::HasArguments<#lifetime>>::ArgumentBuffer) -> sqlx::encode::IsNull {
<#ty as sqlx::encode::Encode<#lifetime, sqlx::Sqlite>>::encode_by_ref(&self.0, buf)
}
fn produces(&self) -> Option<sqlx::sqlite::SqliteTypeInfo> {
<#ty as sqlx::encode::Encode<#lifetime, sqlx::Sqlite>>::produces(&self.0)
}
fn size_hint(&self) -> usize {
<#ty as sqlx::encode::Encode<#lifetime, sqlx::Sqlite>>::size_hint(&self.0)
}
}
));
}
Ok(tts)
} }
fn expand_derive_encode_weak_enum( fn expand_derive_encode_weak_enum(
@ -101,21 +143,63 @@ fn expand_derive_encode_weak_enum(
let ident = &input.ident; let ident = &input.ident;
Ok(quote!( let mut tts = proc_macro2::TokenStream::new();
impl<'q, DB: sqlx::Database> sqlx::encode::Encode<'q, DB> for #ident where #repr: sqlx::encode::Encode<'q, DB> {
fn encode_by_ref(&self, buf: &mut <DB as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull { if cfg!(feature = "mysql") {
<#repr as sqlx::encode::Encode<DB>>::encode_by_ref(&(*self as #repr), buf) tts.extend(quote!(
impl<'q> sqlx::encode::Encode<'q, sqlx::MySql> for #ident where #repr: sqlx::encode::Encode<'q, sqlx::MySql> {
fn encode_by_ref(&self, buf: &mut <sqlx::MySql as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
<#repr as sqlx::encode::Encode<sqlx::MySql>>::encode_by_ref(&(*self as #repr), buf)
} }
fn produces(&self) -> Option<DB::TypeInfo> { fn produces(&self) -> Option<sqlx::mysql::MySqlTypeInfo> {
<#repr as sqlx::encode::Encode<DB>>::produces(&(*self as #repr)) <#repr as sqlx::encode::Encode<sqlx::MySql>>::produces(&(*self as #repr))
} }
fn size_hint(&self) -> usize { fn size_hint(&self) -> usize {
<#repr as sqlx::encode::Encode<DB>>::size_hint(&(*self as #repr)) <#repr as sqlx::encode::Encode<sqlx::MySql>>::size_hint(&(*self as #repr))
} }
} }
)) ));
}
if cfg!(feature = "postgres") {
tts.extend(quote!(
impl<'q> sqlx::encode::Encode<'q, sqlx::Postgres> for #ident where #repr: sqlx::encode::Encode<'q, sqlx::Postgres> {
fn encode_by_ref(&self, buf: &mut <sqlx::Postgres as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
<#repr as sqlx::encode::Encode<sqlx::Postgres>>::encode_by_ref(&(*self as #repr), buf)
}
fn produces(&self) -> Option<sqlx::postgres::PgTypeInfo> {
<#repr as sqlx::encode::Encode<sqlx::Postgres>>::produces(&(*self as #repr))
}
fn size_hint(&self) -> usize {
<#repr as sqlx::encode::Encode<sqlx::Postgres>>::size_hint(&(*self as #repr))
}
}
));
}
if cfg!(feature = "sqlite") {
tts.extend(quote!(
impl<'q> sqlx::encode::Encode<'q, sqlx::Sqlite> for #ident where #repr: sqlx::encode::Encode<'q, sqlx::Sqlite> {
fn encode_by_ref(&self, buf: &mut <sqlx::Sqlite as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
<#repr as sqlx::encode::Encode<sqlx::Sqlite>>::encode_by_ref(&(*self as #repr), buf)
}
fn produces(&self) -> Option<sqlx::sqlite::SqliteTypeInfo> {
<#repr as sqlx::encode::Encode<sqlx::Sqlite>>::produces(&(*self as #repr))
}
fn size_hint(&self) -> usize {
<#repr as sqlx::encode::Encode<sqlx::Sqlite>>::size_hint(&(*self as #repr))
}
}
));
}
Ok(tts)
} }
fn expand_derive_encode_strong_enum( fn expand_derive_encode_strong_enum(
@ -143,14 +227,17 @@ fn expand_derive_encode_strong_enum(
} }
} }
Ok(quote!( let mut tts = proc_macro2::TokenStream::new();
impl<'q, DB: sqlx::Database> sqlx::encode::Encode<'q, DB> for #ident where &'q str: sqlx::encode::Encode<'q, DB> {
fn encode_by_ref(&self, buf: &mut <DB as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull { if cfg!(feature = "mysql") {
tts.extend(quote!(
impl<'q> sqlx::encode::Encode<'q, sqlx::MySql> for #ident where &'q str: sqlx::encode::Encode<'q, sqlx::MySql> {
fn encode_by_ref(&self, buf: &mut <sqlx::MySql as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
let val = match self { let val = match self {
#(#value_arms)* #(#value_arms)*
}; };
<&str as sqlx::encode::Encode<'q, DB>>::encode(val, buf) <&str as sqlx::encode::Encode<'q, sqlx::MySql>>::encode(val, buf)
} }
fn size_hint(&self) -> usize { fn size_hint(&self) -> usize {
@ -158,10 +245,57 @@ fn expand_derive_encode_strong_enum(
#(#value_arms)* #(#value_arms)*
}; };
<&str as sqlx::encode::Encode<'q, DB>>::size_hint(&val) <&str as sqlx::encode::Encode<'q, sqlx::MySql>>::size_hint(&val)
} }
} }
)) ));
}
if cfg!(feature = "postgres") {
tts.extend(quote!(
impl<'q> sqlx::encode::Encode<'q, sqlx::Postgres> for #ident where &'q str: sqlx::encode::Encode<'q, sqlx::Postgres> {
fn encode_by_ref(&self, buf: &mut <sqlx::Postgres as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
let val = match self {
#(#value_arms)*
};
<&str as sqlx::encode::Encode<'q, sqlx::Postgres>>::encode(val, buf)
}
fn size_hint(&self) -> usize {
let val = match self {
#(#value_arms)*
};
<&str as sqlx::encode::Encode<'q, sqlx::Postgres>>::size_hint(&val)
}
}
));
}
if cfg!(feature = "sqlite") {
tts.extend(quote!(
impl<'q> sqlx::encode::Encode<'q, sqlx::Sqlite> for #ident where &'q str: sqlx::encode::Encode<'q, sqlx::Sqlite> {
fn encode_by_ref(&self, buf: &mut <sqlx::Sqlite as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
let val = match self {
#(#value_arms)*
};
<&str as sqlx::encode::Encode<'q, sqlx::Sqlite>>::encode(val, buf)
}
fn size_hint(&self) -> usize {
let val = match self {
#(#value_arms)*
};
<&str as sqlx::encode::Encode<'q, sqlx::Sqlite>>::size_hint(&val)
}
}
));
}
Ok(tts)
} }
fn expand_derive_encode_struct( fn expand_derive_encode_struct(

View File

@ -59,23 +59,66 @@ fn expand_derive_has_sql_type_transparent(
if attr.transparent { if attr.transparent {
let mut generics = generics.clone(); let mut generics = generics.clone();
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
let mut tts = proc_macro2::TokenStream::new();
if cfg!(feature = "mysql") {
generics generics
.make_where_clause() .make_where_clause()
.predicates .predicates
.push(parse_quote!(#ty: sqlx::Type<DB>)); .push(parse_quote!(#ty: sqlx::Type<sqlx::MySql>));
let (impl_generics, _, where_clause) = generics.split_for_impl(); let (impl_generics, _, where_clause) = generics.split_for_impl();
return Ok(quote!( tts.extend(quote!(
impl #impl_generics sqlx::Type< DB > for #ident #ty_generics #where_clause { impl #impl_generics sqlx::Type<sqlx::MySql> for #ident #ty_generics #where_clause
fn type_info() -> DB::TypeInfo { {
<#ty as sqlx::Type<DB>>::type_info() fn type_info() -> sqlx::mysql::MySqlTypeInfo {
<#ty as sqlx::Type<sqlx::MySql>>::type_info()
} }
} }
)); ));
} }
if cfg!(feature = "postgres") {
generics
.make_where_clause()
.predicates
.push(parse_quote!(#ty: sqlx::Type<sqlx::Postgres>));
let (impl_generics, _, where_clause) = generics.split_for_impl();
tts.extend(quote!(
impl #impl_generics sqlx::Type<sqlx::Postgres> for #ident #ty_generics #where_clause
{
fn type_info() -> sqlx::postgres::PgTypeInfo {
<#ty as sqlx::Type<sqlx::Postgres>>::type_info()
}
}
));
}
if cfg!(feature = "sqlite") {
generics
.make_where_clause()
.predicates
.push(parse_quote!(#ty: sqlx::Type<sqlx::Sqlite>));
let (impl_generics, _, where_clause) = generics.split_for_impl();
tts.extend(quote!(
impl #impl_generics sqlx::Type<sqlx::Sqlite> for #ident #ty_generics #where_clause
{
fn type_info() -> sqlx::sqlite::SqliteTypeInfo {
<#ty as sqlx::Type<sqlx::Sqlite>>::type_info()
}
}
));
}
return Ok(tts);
}
let mut tts = proc_macro2::TokenStream::new(); let mut tts = proc_macro2::TokenStream::new();
if cfg!(feature = "postgres") { if cfg!(feature = "postgres") {
@ -100,18 +143,49 @@ fn expand_derive_has_sql_type_weak_enum(
let attr = check_weak_enum_attributes(input, variants)?; let attr = check_weak_enum_attributes(input, variants)?;
let repr = attr.repr.unwrap(); let repr = attr.repr.unwrap();
let ident = &input.ident; let ident = &input.ident;
let ts = quote!(
impl<DB: sqlx::Database> sqlx::Type<DB> for #ident
where
#repr: sqlx::Type<DB>,
{
fn type_info() -> DB::TypeInfo {
<#repr as sqlx::Type<DB>>::type_info()
}
}
);
Ok(ts) let mut tts = proc_macro2::TokenStream::new();
if cfg!(feature = "mysql") {
tts.extend(quote!(
impl sqlx::Type<sqlx::MySql> for #ident
where
#repr: sqlx::Type<sqlx::MySql>,
{
fn type_info() -> sqlx::mysql::MySqlTypeInfo {
<#repr as sqlx::Type<sqlx::MySql>>::type_info()
}
}
));
}
if cfg!(feature = "postgres") {
tts.extend(quote!(
impl sqlx::Type<sqlx::Postgres> for #ident
where
#repr: sqlx::Type<sqlx::Postgres>,
{
fn type_info() -> sqlx::postgres::PgTypeInfo {
<#repr as sqlx::Type<sqlx::Postgres>>::type_info()
}
}
));
}
if cfg!(feature = "sqlite") {
tts.extend(quote!(
impl sqlx::Type<sqlx::Sqlite> for #ident
where
#repr: sqlx::Type<sqlx::Sqlite>,
{
fn type_info() -> sqlx::sqlite::SqliteTypeInfo {
<#repr as sqlx::Type<sqlx::Sqlite>>::type_info()
}
}
));
}
Ok(tts)
} }
fn expand_derive_has_sql_type_strong_enum( fn expand_derive_has_sql_type_strong_enum(