From 9134cff1557f7e6ce5a6419e417d434beaf309b9 Mon Sep 17 00:00:00 2001 From: Erick Tryzelaar Date: Sun, 8 Mar 2015 11:39:20 -0700 Subject: [PATCH] #[derive_deserialize] for generic tuple structs --- serde2/serde2_macros/src/lib.rs | 102 +++++++++++++++++++++++++------- serde2/tests/test_macros.rs | 31 ++++++++-- 2 files changed, 107 insertions(+), 26 deletions(-) diff --git a/serde2/serde2_macros/src/lib.rs b/serde2/serde2_macros/src/lib.rs index 0db85d69..c28438e0 100644 --- a/serde2/serde2_macros/src/lib.rs +++ b/serde2/serde2_macros/src/lib.rs @@ -120,10 +120,12 @@ fn expand_derive_serialize( trait_def.expand(cx, mitem, item, |item| push(item)) } -fn serialize_substructure(cx: &ExtCtxt, - span: Span, - substr: &Substructure, - item: &Item) -> P { +fn serialize_substructure( + cx: &ExtCtxt, + span: Span, + substr: &Substructure, + item: &Item, +) -> P { let ctx = aster::Ctx::new(); let builder = aster::AstBuilder::new(&ctx).span(span); @@ -635,7 +637,7 @@ pub fn expand_derive_deserialize( ), attributes: attrs, combine_substructure: combine_substructure(Box::new(|a, b, c| { - deserialize_substructure(a, b, c) + deserialize_substructure(a, b, c, item) })), }) }; @@ -643,11 +645,16 @@ pub fn expand_derive_deserialize( trait_def.expand(cx, mitem, item, |item| push(item)) } -fn deserialize_substructure(cx: &ExtCtxt, span: Span, substr: &Substructure) -> P { +fn deserialize_substructure( + cx: &ExtCtxt, + span: Span, + substr: &Substructure, + item: &Item, +) -> P { let state = substr.nonself_args[0].clone(); - match *substr.fields { - StaticStruct(ref struct_def, ref fields) => { + match (&item.node, &*substr.fields) { + (&ast::ItemStruct(_, ref generics), &StaticStruct(ref struct_def, ref fields)) => { deserialize_struct( cx, span, @@ -656,16 +663,20 @@ fn deserialize_substructure(cx: &ExtCtxt, span: Span, substr: &Substructure) -> cx.path(span, vec![substr.type_ident]), fields, state, - struct_def) + struct_def, + generics, + ) } - StaticEnum(ref enum_def, ref fields) => { + (&ast::ItemEnum(_, ref generics), &StaticEnum(ref enum_def, ref fields)) => { deserialize_enum( cx, span, substr.type_ident, &fields, state, - enum_def) + enum_def, + generics, + ) } _ => cx.bug("expected StaticEnum or StaticStruct in derive(Deserialize)") } @@ -679,7 +690,8 @@ fn deserialize_struct( struct_path: ast::Path, fields: &StaticFields, state: P, - struct_def: &StructDef + struct_def: &StructDef, + generics: &ast::Generics, ) -> P { match *fields { Unnamed(ref fields) => { @@ -699,7 +711,9 @@ fn deserialize_struct( struct_ident, struct_path, &fields, - state) + state, + generics, + ) } } Named(ref fields) => { @@ -774,11 +788,21 @@ fn deserialize_struct_unnamed_fields( struct_path: ast::Path, fields: &[Span], state: P, + generics: &ast::Generics, ) -> P { - let struct_name = cx.expr_str(span, token::get_ident(struct_ident)); + let ctx = aster::Ctx::new(); + let builder = aster::AstBuilder::new(&ctx).span(span); + + let visitor_impl_generics = builder.from_generics(generics.clone()) + .add_ty_param_bound( + builder.path().global().ids(&["serde2", "de", "Deserialize"]).build() + ) + .build(); + + let struct_name = builder.expr().str(struct_ident); let field_names: Vec = (0 .. fields.len()) - .map(|i| token::str_to_ident(&format!("__field{}", i))) + .map(|i| builder.id(&format!("__field{}", i))) .collect(); let visit_seq_expr = declare_visit_seq( @@ -788,13 +812,48 @@ fn deserialize_struct_unnamed_fields( &field_names, ); + // Build `__Visitor(PhantomData, PhantomData, ...)` + let (visitor_struct, visitor_expr) = if generics.ty_params.is_empty() { + ( + builder.item().tuple_struct("__Visitor") + .build(), + builder.expr().id("__Visitor"), + ) + } else { + ( + builder.item().tuple_struct("__Visitor") + .with_generics(generics.clone()) + .with_tys( + generics.ty_params.iter().map(|ty_param| { + builder.ty().phantom_data().id(ty_param.ident) + }) + ) + .build(), + builder.expr().call().id("__Visitor") + .with_args( + generics.ty_params.iter().map(|_| { + builder.expr().phantom_data() + }) + ) + .build(), + ) + }; + + let visitor_ty = builder.ty().path() + .segment("__Visitor").with_generics(generics.clone()).build() + .build(); + + let value_ty = builder.ty().path() + .segment(type_ident).with_generics(generics.clone()).build() + .build(); + quote_expr!(cx, { - struct __Visitor; + $visitor_struct; - impl ::serde2::de::Visitor for __Visitor { - type Value = $type_ident; + impl $visitor_impl_generics ::serde2::de::Visitor for $visitor_ty { + type Value = $value_ty; - fn visit_seq<__V>(&mut self, mut visitor: __V) -> Result<$type_ident, __V::Error> + fn visit_seq<__V>(&mut self, mut visitor: __V) -> Result<$value_ty, __V::Error> where __V: ::serde2::de::SeqVisitor, { $visit_seq_expr @@ -802,7 +861,7 @@ fn deserialize_struct_unnamed_fields( fn visit_named_seq<__V>(&mut self, name: &str, - visitor: __V) -> Result<$type_ident, __V::Error> + visitor: __V) -> Result<$value_ty, __V::Error> where __V: ::serde2::de::SeqVisitor, { if name == $struct_name { @@ -813,7 +872,7 @@ fn deserialize_struct_unnamed_fields( } } - $state.visit(__Visitor) + $state.visit($visitor_expr) }) } @@ -1132,6 +1191,7 @@ fn deserialize_enum( fields: &[(Ident, Span, StaticFields)], state: P, enum_def: &EnumDef, + _generics: &ast::Generics, ) -> P { let type_name = cx.expr_str(span, token::get_ident(type_ident)); diff --git a/serde2/tests/test_macros.rs b/serde2/tests/test_macros.rs index 2ff1d8ee..30a73e66 100644 --- a/serde2/tests/test_macros.rs +++ b/serde2/tests/test_macros.rs @@ -30,10 +30,6 @@ trait Trait { #[derive_deserialize] struct NamedUnit; -#[derive(Debug, PartialEq)] -#[derive_serialize] -struct NamedTuple<'a, 'b, A: 'a, B: 'b, C>(&'a A, &'b mut B, C); - #[derive(Debug, PartialEq)] #[derive_serialize] struct NamedMap<'a, 'b, A: 'a, B: 'b, C> { @@ -87,7 +83,11 @@ fn test_named_unit() { } #[test] -fn test_named_tuple() { +fn test_ser_named_tuple() { + #[derive(Debug, PartialEq)] + #[derive_serialize] + struct NamedTuple<'a, 'b, A: 'a, B: 'b, C>(&'a A, &'b mut B, C); + let a = 5; let mut b = 6; let c = 7; @@ -104,6 +104,27 @@ fn test_named_tuple() { ); } +#[test] +fn test_de_named_tuple() { + #[derive(Debug, PartialEq)] + #[derive_deserialize] + struct NamedTuple(A, B, C); + + assert_eq!( + json::from_str("[1,2,3]").unwrap(), + NamedTuple(1, 2, 3) + ); + + assert_eq!( + json::from_str("[1,2,3]").unwrap(), + Value::Array(vec![ + Value::I64(1), + Value::I64(2), + Value::I64(3), + ]) + ); +} + #[test] fn test_named_map() { let a = 5;