mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-12-28 11:39:07 +00:00
Allow single-field named structs to be transparent (#3971)
* Allow single-field named structs to be transparent This more closely matches the criteria for e.g. #[repr(transparent)] and #[serde(transparent)]. * Add tests, fix error messages
This commit is contained in:
parent
ff93aa017a
commit
a301d9abad
@ -281,7 +281,7 @@ pub fn check_struct_attributes(
|
||||
|
||||
assert_attribute!(
|
||||
!attributes.transparent,
|
||||
"unexpected #[sqlx(transparent)]",
|
||||
"#[sqlx(transparent)] is only valid for structs with exactly one field",
|
||||
input
|
||||
);
|
||||
|
||||
|
||||
@ -8,18 +8,17 @@ use quote::quote;
|
||||
use syn::punctuated::Punctuated;
|
||||
use syn::token::Comma;
|
||||
use syn::{
|
||||
parse_quote, Arm, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed,
|
||||
FieldsUnnamed, Stmt, TypeParamBound, Variant,
|
||||
parse_quote, Arm, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed, Stmt,
|
||||
TypeParamBound, Variant,
|
||||
};
|
||||
|
||||
pub fn expand_derive_decode(input: &DeriveInput) -> syn::Result<TokenStream> {
|
||||
let attrs = parse_container_attributes(&input.attrs)?;
|
||||
match &input.data {
|
||||
Data::Struct(DataStruct {
|
||||
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
|
||||
..
|
||||
}) if unnamed.len() == 1 => {
|
||||
expand_derive_decode_transparent(input, unnamed.first().unwrap())
|
||||
Data::Struct(DataStruct { fields, .. })
|
||||
if fields.len() == 1 && (matches!(fields, Fields::Unnamed(_)) || attrs.transparent) =>
|
||||
{
|
||||
expand_derive_decode_transparent(input, fields.iter().next().unwrap())
|
||||
}
|
||||
Data::Enum(DataEnum { variants, .. }) => match attrs.repr {
|
||||
Some(_) => expand_derive_decode_weak_enum(input, variants),
|
||||
@ -35,7 +34,7 @@ pub fn expand_derive_decode(input: &DeriveInput) -> syn::Result<TokenStream> {
|
||||
..
|
||||
}) => Err(syn::Error::new_spanned(
|
||||
input,
|
||||
"structs with zero or more than one unnamed field are not supported",
|
||||
"tuple structs may only have a single field",
|
||||
)),
|
||||
Data::Struct(DataStruct {
|
||||
fields: Fields::Unit,
|
||||
@ -72,6 +71,12 @@ fn expand_derive_decode_transparent(
|
||||
.push(parse_quote!(#ty: ::sqlx::decode::Decode<'r, DB>));
|
||||
let (impl_generics, _, where_clause) = generics.split_for_impl();
|
||||
|
||||
let field_ident = if let Some(ident) = &field.ident {
|
||||
quote! { #ident }
|
||||
} else {
|
||||
quote! { 0 }
|
||||
};
|
||||
|
||||
let tts = quote!(
|
||||
#[automatically_derived]
|
||||
impl #impl_generics ::sqlx::decode::Decode<'r, DB> for #ident #ty_generics #where_clause {
|
||||
@ -83,7 +88,8 @@ fn expand_derive_decode_transparent(
|
||||
dyn ::std::error::Error + 'static + ::std::marker::Send + ::std::marker::Sync,
|
||||
>,
|
||||
> {
|
||||
<#ty as ::sqlx::decode::Decode<'r, DB>>::decode(value).map(Self)
|
||||
<#ty as ::sqlx::decode::Decode<'r, DB>>::decode(value)
|
||||
.map(|val| Self { #field_ident: val })
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
@ -9,18 +9,17 @@ use syn::punctuated::Punctuated;
|
||||
use syn::token::Comma;
|
||||
use syn::{
|
||||
parse_quote, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed,
|
||||
FieldsUnnamed, Lifetime, LifetimeParam, Stmt, TypeParamBound, Variant,
|
||||
Lifetime, LifetimeParam, Stmt, TypeParamBound, Variant,
|
||||
};
|
||||
|
||||
pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result<TokenStream> {
|
||||
let args = parse_container_attributes(&input.attrs)?;
|
||||
|
||||
match &input.data {
|
||||
Data::Struct(DataStruct {
|
||||
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
|
||||
..
|
||||
}) if unnamed.len() == 1 => {
|
||||
expand_derive_encode_transparent(input, unnamed.first().unwrap())
|
||||
Data::Struct(DataStruct { fields, .. })
|
||||
if fields.len() == 1 && (matches!(fields, Fields::Unnamed(_)) || args.transparent) =>
|
||||
{
|
||||
expand_derive_encode_transparent(input, fields.iter().next().unwrap())
|
||||
}
|
||||
Data::Enum(DataEnum { variants, .. }) => match args.repr {
|
||||
Some(_) => expand_derive_encode_weak_enum(input, variants),
|
||||
@ -36,7 +35,7 @@ pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result<TokenStream> {
|
||||
..
|
||||
}) => Err(syn::Error::new_spanned(
|
||||
input,
|
||||
"structs with zero or more than one unnamed field are not supported",
|
||||
"tuple structs may only have a single field",
|
||||
)),
|
||||
Data::Struct(DataStruct {
|
||||
fields: Fields::Unit,
|
||||
@ -77,6 +76,12 @@ fn expand_derive_encode_transparent(
|
||||
.push(parse_quote!(#ty: ::sqlx::encode::Encode<#lifetime, DB>));
|
||||
let (impl_generics, _, where_clause) = generics.split_for_impl();
|
||||
|
||||
let field_ident = if let Some(ident) = &field.ident {
|
||||
quote! { #ident }
|
||||
} else {
|
||||
quote! { 0 }
|
||||
};
|
||||
|
||||
Ok(quote!(
|
||||
#[automatically_derived]
|
||||
impl #impl_generics ::sqlx::encode::Encode<#lifetime, DB> for #ident #ty_generics
|
||||
@ -86,15 +91,15 @@ fn expand_derive_encode_transparent(
|
||||
&self,
|
||||
buf: &mut <DB as ::sqlx::database::Database>::ArgumentBuffer<#lifetime>,
|
||||
) -> ::std::result::Result<::sqlx::encode::IsNull, ::sqlx::error::BoxDynError> {
|
||||
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::encode_by_ref(&self.0, buf)
|
||||
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::encode_by_ref(&self.#field_ident, buf)
|
||||
}
|
||||
|
||||
fn produces(&self) -> Option<DB::TypeInfo> {
|
||||
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::produces(&self.0)
|
||||
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::produces(&self.#field_ident)
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> usize {
|
||||
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::size_hint(&self.0)
|
||||
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::size_hint(&self.#field_ident)
|
||||
}
|
||||
}
|
||||
))
|
||||
|
||||
@ -7,8 +7,7 @@ use quote::{quote, quote_spanned};
|
||||
use syn::punctuated::Punctuated;
|
||||
use syn::token::Comma;
|
||||
use syn::{
|
||||
parse_quote, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed,
|
||||
FieldsUnnamed, Variant,
|
||||
parse_quote, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed, Variant,
|
||||
};
|
||||
|
||||
pub fn expand_derive_type(input: &DeriveInput) -> syn::Result<TokenStream> {
|
||||
@ -16,18 +15,11 @@ pub fn expand_derive_type(input: &DeriveInput) -> syn::Result<TokenStream> {
|
||||
match &input.data {
|
||||
// Newtype structs:
|
||||
// struct Foo(i32);
|
||||
Data::Struct(DataStruct {
|
||||
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
|
||||
..
|
||||
}) => {
|
||||
if unnamed.len() == 1 {
|
||||
expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap())
|
||||
} else {
|
||||
Err(syn::Error::new_spanned(
|
||||
input,
|
||||
"structs with zero or more than one unnamed field are not supported",
|
||||
))
|
||||
}
|
||||
// struct Foo { field: i32 };
|
||||
Data::Struct(DataStruct { fields, .. })
|
||||
if fields.len() == 1 && (matches!(fields, Fields::Unnamed(_)) || attrs.transparent) =>
|
||||
{
|
||||
expand_derive_has_sql_type_transparent(input, fields.iter().next().unwrap())
|
||||
}
|
||||
// Record types
|
||||
// struct Foo { foo: i32, bar: String }
|
||||
@ -35,6 +27,13 @@ pub fn expand_derive_type(input: &DeriveInput) -> syn::Result<TokenStream> {
|
||||
fields: Fields::Named(FieldsNamed { named, .. }),
|
||||
..
|
||||
}) => expand_derive_has_sql_type_struct(input, named),
|
||||
Data::Struct(DataStruct {
|
||||
fields: Fields::Unnamed(..),
|
||||
..
|
||||
}) => Err(syn::Error::new_spanned(
|
||||
input,
|
||||
"tuple structs may only have a single field",
|
||||
)),
|
||||
Data::Struct(DataStruct {
|
||||
fields: Fields::Unit,
|
||||
..
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
use sqlx_mysql::MySql;
|
||||
use sqlx_test::new;
|
||||
use sqlx_test::{new, test_type};
|
||||
|
||||
#[sqlx::test]
|
||||
async fn test_derive_strong_enum() -> anyhow::Result<()> {
|
||||
@ -300,3 +300,23 @@ async fn test_derive_weak_enum() -> anyhow::Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Debug, sqlx::Type)]
|
||||
#[sqlx(transparent)]
|
||||
struct TransparentTuple(i64);
|
||||
|
||||
#[derive(PartialEq, Eq, Debug, sqlx::Type)]
|
||||
#[sqlx(transparent)]
|
||||
struct TransparentNamed {
|
||||
field: i64,
|
||||
}
|
||||
|
||||
test_type!(transparent_tuple<TransparentTuple>(MySql,
|
||||
"0" == TransparentTuple(0),
|
||||
"23523" == TransparentTuple(23523)
|
||||
));
|
||||
|
||||
test_type!(transparent_named<TransparentNamed>(MySql,
|
||||
"0" == TransparentNamed { field: 0 },
|
||||
"23523" == TransparentNamed { field: 23523 },
|
||||
));
|
||||
|
||||
@ -12,6 +12,13 @@ use std::ops::Bound;
|
||||
#[sqlx(transparent)]
|
||||
struct Transparent(i32);
|
||||
|
||||
// Also possible for single-field named structs
|
||||
#[derive(PartialEq, Debug, sqlx::Type)]
|
||||
#[sqlx(transparent)]
|
||||
struct TransparentNamed {
|
||||
field: i32,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Debug, sqlx::Type)]
|
||||
// https://github.com/launchbadge/sqlx/issues/2611
|
||||
// Previously, the derive would generate a `PgHasArrayType` impl that errored on an
|
||||
@ -143,11 +150,16 @@ struct FloatRange(PgRange<f64>);
|
||||
#[sqlx(type_name = "int4rangeL0pC")]
|
||||
struct RangeInclusive(PgRange<i32>);
|
||||
|
||||
test_type!(transparent<Transparent>(Postgres,
|
||||
test_type!(transparent_tuple<Transparent>(Postgres,
|
||||
"0" == Transparent(0),
|
||||
"23523" == Transparent(23523)
|
||||
));
|
||||
|
||||
test_type!(transparent_named<TransparentNamed>(Postgres,
|
||||
"0" == TransparentNamed { field: 0 },
|
||||
"23523" == TransparentNamed { field: 23523 },
|
||||
));
|
||||
|
||||
test_type!(transparent_array<TransparentArray>(Postgres,
|
||||
"'{}'::int8[]" == TransparentArray(vec![]),
|
||||
"'{ 23523, 123456, 789 }'::int8[]" == TransparentArray(vec![23523, 123456, 789])
|
||||
|
||||
@ -12,3 +12,23 @@ test_type!(origin_enum<Origin>(Sqlite,
|
||||
"1" == Origin::Foo,
|
||||
"2" == Origin::Bar,
|
||||
));
|
||||
|
||||
#[derive(PartialEq, Eq, Debug, sqlx::Type)]
|
||||
#[sqlx(transparent)]
|
||||
struct TransparentTuple(i64);
|
||||
|
||||
#[derive(PartialEq, Eq, Debug, sqlx::Type)]
|
||||
#[sqlx(transparent)]
|
||||
struct TransparentNamed {
|
||||
field: i64,
|
||||
}
|
||||
|
||||
test_type!(transparent_tuple<TransparentTuple>(Sqlite,
|
||||
"0" == TransparentTuple(0),
|
||||
"23523" == TransparentTuple(23523)
|
||||
));
|
||||
|
||||
test_type!(transparent_named<TransparentNamed>(Sqlite,
|
||||
"0" == TransparentNamed { field: 0 },
|
||||
"23523" == TransparentNamed { field: 23523 },
|
||||
));
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user