Allow single-field named structs to be transparent (#3971)

* Allow single-field named structs to be transparent

This more closely matches the criteria for e.g. #[repr(transparent)]
and #[serde(transparent)].

* Add tests, fix error messages
This commit is contained in:
Xiretza 2025-08-21 20:57:35 +00:00 committed by GitHub
parent ff93aa017a
commit a301d9abad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 98 additions and 36 deletions

View File

@ -281,7 +281,7 @@ pub fn check_struct_attributes(
assert_attribute!(
!attributes.transparent,
"unexpected #[sqlx(transparent)]",
"#[sqlx(transparent)] is only valid for structs with exactly one field",
input
);

View File

@ -8,18 +8,17 @@ use quote::quote;
use syn::punctuated::Punctuated;
use syn::token::Comma;
use syn::{
parse_quote, Arm, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed,
FieldsUnnamed, Stmt, TypeParamBound, Variant,
parse_quote, Arm, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed, Stmt,
TypeParamBound, Variant,
};
pub fn expand_derive_decode(input: &DeriveInput) -> syn::Result<TokenStream> {
let attrs = parse_container_attributes(&input.attrs)?;
match &input.data {
Data::Struct(DataStruct {
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
..
}) if unnamed.len() == 1 => {
expand_derive_decode_transparent(input, unnamed.first().unwrap())
Data::Struct(DataStruct { fields, .. })
if fields.len() == 1 && (matches!(fields, Fields::Unnamed(_)) || attrs.transparent) =>
{
expand_derive_decode_transparent(input, fields.iter().next().unwrap())
}
Data::Enum(DataEnum { variants, .. }) => match attrs.repr {
Some(_) => expand_derive_decode_weak_enum(input, variants),
@ -35,7 +34,7 @@ pub fn expand_derive_decode(input: &DeriveInput) -> syn::Result<TokenStream> {
..
}) => Err(syn::Error::new_spanned(
input,
"structs with zero or more than one unnamed field are not supported",
"tuple structs may only have a single field",
)),
Data::Struct(DataStruct {
fields: Fields::Unit,
@ -72,6 +71,12 @@ fn expand_derive_decode_transparent(
.push(parse_quote!(#ty: ::sqlx::decode::Decode<'r, DB>));
let (impl_generics, _, where_clause) = generics.split_for_impl();
let field_ident = if let Some(ident) = &field.ident {
quote! { #ident }
} else {
quote! { 0 }
};
let tts = quote!(
#[automatically_derived]
impl #impl_generics ::sqlx::decode::Decode<'r, DB> for #ident #ty_generics #where_clause {
@ -83,7 +88,8 @@ fn expand_derive_decode_transparent(
dyn ::std::error::Error + 'static + ::std::marker::Send + ::std::marker::Sync,
>,
> {
<#ty as ::sqlx::decode::Decode<'r, DB>>::decode(value).map(Self)
<#ty as ::sqlx::decode::Decode<'r, DB>>::decode(value)
.map(|val| Self { #field_ident: val })
}
}
);

View File

@ -9,18 +9,17 @@ use syn::punctuated::Punctuated;
use syn::token::Comma;
use syn::{
parse_quote, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed,
FieldsUnnamed, Lifetime, LifetimeParam, Stmt, TypeParamBound, Variant,
Lifetime, LifetimeParam, Stmt, TypeParamBound, Variant,
};
pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result<TokenStream> {
let args = parse_container_attributes(&input.attrs)?;
match &input.data {
Data::Struct(DataStruct {
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
..
}) if unnamed.len() == 1 => {
expand_derive_encode_transparent(input, unnamed.first().unwrap())
Data::Struct(DataStruct { fields, .. })
if fields.len() == 1 && (matches!(fields, Fields::Unnamed(_)) || args.transparent) =>
{
expand_derive_encode_transparent(input, fields.iter().next().unwrap())
}
Data::Enum(DataEnum { variants, .. }) => match args.repr {
Some(_) => expand_derive_encode_weak_enum(input, variants),
@ -36,7 +35,7 @@ pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result<TokenStream> {
..
}) => Err(syn::Error::new_spanned(
input,
"structs with zero or more than one unnamed field are not supported",
"tuple structs may only have a single field",
)),
Data::Struct(DataStruct {
fields: Fields::Unit,
@ -77,6 +76,12 @@ fn expand_derive_encode_transparent(
.push(parse_quote!(#ty: ::sqlx::encode::Encode<#lifetime, DB>));
let (impl_generics, _, where_clause) = generics.split_for_impl();
let field_ident = if let Some(ident) = &field.ident {
quote! { #ident }
} else {
quote! { 0 }
};
Ok(quote!(
#[automatically_derived]
impl #impl_generics ::sqlx::encode::Encode<#lifetime, DB> for #ident #ty_generics
@ -86,15 +91,15 @@ fn expand_derive_encode_transparent(
&self,
buf: &mut <DB as ::sqlx::database::Database>::ArgumentBuffer<#lifetime>,
) -> ::std::result::Result<::sqlx::encode::IsNull, ::sqlx::error::BoxDynError> {
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::encode_by_ref(&self.0, buf)
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::encode_by_ref(&self.#field_ident, buf)
}
fn produces(&self) -> Option<DB::TypeInfo> {
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::produces(&self.0)
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::produces(&self.#field_ident)
}
fn size_hint(&self) -> usize {
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::size_hint(&self.0)
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::size_hint(&self.#field_ident)
}
}
))

View File

@ -7,8 +7,7 @@ use quote::{quote, quote_spanned};
use syn::punctuated::Punctuated;
use syn::token::Comma;
use syn::{
parse_quote, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed,
FieldsUnnamed, Variant,
parse_quote, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed, Variant,
};
pub fn expand_derive_type(input: &DeriveInput) -> syn::Result<TokenStream> {
@ -16,18 +15,11 @@ pub fn expand_derive_type(input: &DeriveInput) -> syn::Result<TokenStream> {
match &input.data {
// Newtype structs:
// struct Foo(i32);
Data::Struct(DataStruct {
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
..
}) => {
if unnamed.len() == 1 {
expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap())
} else {
Err(syn::Error::new_spanned(
input,
"structs with zero or more than one unnamed field are not supported",
))
}
// struct Foo { field: i32 };
Data::Struct(DataStruct { fields, .. })
if fields.len() == 1 && (matches!(fields, Fields::Unnamed(_)) || attrs.transparent) =>
{
expand_derive_has_sql_type_transparent(input, fields.iter().next().unwrap())
}
// Record types
// struct Foo { foo: i32, bar: String }
@ -35,6 +27,13 @@ pub fn expand_derive_type(input: &DeriveInput) -> syn::Result<TokenStream> {
fields: Fields::Named(FieldsNamed { named, .. }),
..
}) => expand_derive_has_sql_type_struct(input, named),
Data::Struct(DataStruct {
fields: Fields::Unnamed(..),
..
}) => Err(syn::Error::new_spanned(
input,
"tuple structs may only have a single field",
)),
Data::Struct(DataStruct {
fields: Fields::Unit,
..

View File

@ -1,5 +1,5 @@
use sqlx_mysql::MySql;
use sqlx_test::new;
use sqlx_test::{new, test_type};
#[sqlx::test]
async fn test_derive_strong_enum() -> anyhow::Result<()> {
@ -300,3 +300,23 @@ async fn test_derive_weak_enum() -> anyhow::Result<()> {
Ok(())
}
#[derive(PartialEq, Eq, Debug, sqlx::Type)]
#[sqlx(transparent)]
struct TransparentTuple(i64);
#[derive(PartialEq, Eq, Debug, sqlx::Type)]
#[sqlx(transparent)]
struct TransparentNamed {
field: i64,
}
test_type!(transparent_tuple<TransparentTuple>(MySql,
"0" == TransparentTuple(0),
"23523" == TransparentTuple(23523)
));
test_type!(transparent_named<TransparentNamed>(MySql,
"0" == TransparentNamed { field: 0 },
"23523" == TransparentNamed { field: 23523 },
));

View File

@ -12,6 +12,13 @@ use std::ops::Bound;
#[sqlx(transparent)]
struct Transparent(i32);
// Also possible for single-field named structs
#[derive(PartialEq, Debug, sqlx::Type)]
#[sqlx(transparent)]
struct TransparentNamed {
field: i32,
}
#[derive(PartialEq, Debug, sqlx::Type)]
// https://github.com/launchbadge/sqlx/issues/2611
// Previously, the derive would generate a `PgHasArrayType` impl that errored on an
@ -143,11 +150,16 @@ struct FloatRange(PgRange<f64>);
#[sqlx(type_name = "int4rangeL0pC")]
struct RangeInclusive(PgRange<i32>);
test_type!(transparent<Transparent>(Postgres,
test_type!(transparent_tuple<Transparent>(Postgres,
"0" == Transparent(0),
"23523" == Transparent(23523)
));
test_type!(transparent_named<TransparentNamed>(Postgres,
"0" == TransparentNamed { field: 0 },
"23523" == TransparentNamed { field: 23523 },
));
test_type!(transparent_array<TransparentArray>(Postgres,
"'{}'::int8[]" == TransparentArray(vec![]),
"'{ 23523, 123456, 789 }'::int8[]" == TransparentArray(vec![23523, 123456, 789])

View File

@ -12,3 +12,23 @@ test_type!(origin_enum<Origin>(Sqlite,
"1" == Origin::Foo,
"2" == Origin::Bar,
));
#[derive(PartialEq, Eq, Debug, sqlx::Type)]
#[sqlx(transparent)]
struct TransparentTuple(i64);
#[derive(PartialEq, Eq, Debug, sqlx::Type)]
#[sqlx(transparent)]
struct TransparentNamed {
field: i64,
}
test_type!(transparent_tuple<TransparentTuple>(Sqlite,
"0" == TransparentTuple(0),
"23523" == TransparentTuple(23523)
));
test_type!(transparent_named<TransparentNamed>(Sqlite,
"0" == TransparentNamed { field: 0 },
"23523" == TransparentNamed { field: 23523 },
));