From ec3af2cb6aaf1636f1eb08ca54ee21ccecdc2031 Mon Sep 17 00:00:00 2001 From: Hugo Duncan Date: Fri, 1 May 2015 12:53:59 -0400 Subject: [PATCH] Factor out attr module Factors out field attribute code into the attr module. --- serde_macros/src/de.rs | 199 +++++++++++++++++++++++--------------- serde_macros/src/field.rs | 18 ++-- serde_macros/src/lib.rs | 5 +- serde_macros/src/ser.rs | 20 +--- src/ser/mod.rs | 2 +- tests/test_annotations.rs | 22 +++++ 6 files changed, 154 insertions(+), 112 deletions(-) diff --git a/serde_macros/src/de.rs b/serde_macros/src/de.rs index 08dfa841..541c7122 100644 --- a/serde_macros/src/de.rs +++ b/serde_macros/src/de.rs @@ -12,10 +12,12 @@ use syntax::ast; use syntax::codemap::Span; use syntax::ext::base::ExtCtxt; use syntax::ext::build::AstBuilder; +use syntax::owned_slice::OwnedSlice; use syntax::ptr::P; use aster; +use attr; use field; pub fn expand_derive_deserialize( @@ -156,14 +158,25 @@ fn deserialize_item_struct( fn deserialize_visitor( builder: &aster::AstBuilder, trait_generics: &ast::Generics, -) -> (P, P, P) { - if trait_generics.ty_params.is_empty() { + forward_ty_params: Vec, + forward_tys: Vec> +) -> (P, P, P, ast::Generics) { + if trait_generics.ty_params.is_empty() && forward_tys.is_empty() { ( - builder.item().tuple_struct("__Visitor").build(), + builder.item().tuple_struct("__Visitor").build(), builder.ty().id("__Visitor"), builder.expr().id("__Visitor"), + trait_generics.clone(), ) } else { + let placeholders : Vec<_> = trait_generics.ty_params.iter() + .map(|_| builder.ty().id("_")) + .collect(); + let mut trait_generics = trait_generics.clone(); + let mut ty_params = forward_ty_params.clone(); + ty_params.extend(trait_generics.ty_params.into_vec()); + trait_generics.ty_params = OwnedSlice::from_vec(ty_params); + ( builder.item().tuple_struct("__Visitor") .generics().with(trait_generics.clone()).build() @@ -176,17 +189,37 @@ fn deserialize_visitor( builder.ty().path() .segment("__Visitor").with_generics(trait_generics.clone()).build() .build(), - builder.expr().call().id("__Visitor") + builder.expr().call() + .path().segment("__Visitor") + .with_tys(forward_tys) + .with_tys(placeholders) + .build().build() .with_args( trait_generics.ty_params.iter().map(|_| { builder.expr().phantom_data() }) ) .build(), + trait_generics, ) } } +fn deserializer_ty_param(builder: &aster::AstBuilder) -> ast::TyParam { + builder.ty_param("__D") + .trait_bound(builder.path() + .segment("serde").build() + .segment("de").build() + .id("Deserializer") + .build()) + .build() + .build() +} + +fn deserializer_ty_arg(builder: &aster::AstBuilder) -> P{ + builder.ty().id("__D") +} + fn deserialize_unit_struct( cx: &ExtCtxt, builder: &aster::AstBuilder, @@ -230,10 +263,13 @@ fn deserialize_tuple_struct( ) -> P { let where_clause = &impl_generics.where_clause; - let (visitor_item, visitor_ty, visitor_expr) = deserialize_visitor( - builder, - impl_generics, - ); + let (visitor_item, visitor_ty, visitor_expr, visitor_generics) = + deserialize_visitor( + builder, + impl_generics, + vec![deserializer_ty_param(builder)], + vec![deserializer_ty_arg(builder)], + ); let visit_seq_expr = deserialize_seq( cx, @@ -247,7 +283,7 @@ fn deserialize_tuple_struct( quote_expr!(cx, { $visitor_item - impl $impl_generics ::serde::de::Visitor for $visitor_ty $where_clause { + impl $visitor_generics ::serde::de::Visitor for $visitor_ty $where_clause { type Value = $ty; fn visit_seq<__V>(&mut self, mut visitor: __V) -> ::std::result::Result<$ty, __V::Error> @@ -305,10 +341,13 @@ fn deserialize_struct( ) -> P { let where_clause = &impl_generics.where_clause; - let (visitor_item, visitor_ty, visitor_expr) = deserialize_visitor( - builder, - impl_generics, - ); + let (visitor_item, visitor_ty, visitor_expr, visitor_generics) = + deserialize_visitor( + builder, + &impl_generics, + vec![deserializer_ty_param(builder)], + vec![deserializer_ty_arg(builder)], + ); let (field_visitor, visit_map_expr) = deserialize_struct_visitor( cx, @@ -324,7 +363,7 @@ fn deserialize_struct( $visitor_item - impl $impl_generics ::serde::de::Visitor for $visitor_ty $where_clause { + impl $visitor_generics ::serde::de::Visitor for $visitor_ty $where_clause { type Value = $ty; #[inline] @@ -356,7 +395,7 @@ fn deserialize_item_enum( builder, enum_def.variants.iter() .map(|variant| - field::FieldLit::Global(builder.expr().str(variant.node.name))) + attr::FieldAttrs::Global(builder.expr().str(variant.node.name))) .collect() ); @@ -381,17 +420,20 @@ fn deserialize_item_enum( }) .collect(); - let (visitor_item, visitor_ty, visitor_expr) = deserialize_visitor( - builder, - impl_generics, - ); + let (visitor_item, visitor_ty, visitor_expr, visitor_generics) = + deserialize_visitor( + builder, + impl_generics, + vec![deserializer_ty_param(builder)], + vec![deserializer_ty_arg(builder)], + ); quote_expr!(cx, { $variant_visitor $visitor_item - impl $impl_generics ::serde::de::EnumVisitor for $visitor_ty $where_clause { + impl $visitor_generics ::serde::de::EnumVisitor for $visitor_ty $where_clause { type Value = $ty; fn visit<__V>(&mut self, mut visitor: __V) -> ::std::result::Result<$ty, __V::Error> @@ -460,10 +502,13 @@ fn deserialize_tuple_variant( ) -> P { let where_clause = &generics.where_clause; - let (visitor_item, visitor_ty, visitor_expr) = deserialize_visitor( - builder, - generics, - ); + let (visitor_item, visitor_ty, visitor_expr, visitor_generics) = + deserialize_visitor( + builder, + generics, + vec![deserializer_ty_param(builder)], + vec![deserializer_ty_arg(builder)], + ); let visit_seq_expr = deserialize_seq( cx, @@ -475,7 +520,7 @@ fn deserialize_tuple_variant( quote_expr!(cx, { $visitor_item - impl $generics ::serde::de::Visitor for $visitor_ty $where_clause { + impl $visitor_generics ::serde::de::Visitor for $visitor_ty $where_clause { type Value = $ty; fn visit_seq<__V>(&mut self, mut visitor: __V) -> ::std::result::Result<$ty, __V::Error> @@ -507,17 +552,20 @@ fn deserialize_struct_variant( builder.path().id(type_ident).id(variant_ident).build(), ); - let (visitor_item, visitor_ty, visitor_expr) = deserialize_visitor( - builder, - generics, - ); + let (visitor_item, visitor_ty, visitor_expr, visitor_generics) = + deserialize_visitor( + builder, + generics, + vec![deserializer_ty_param(builder)], + vec![deserializer_ty_arg(builder)], + ); quote_expr!(cx, { $field_visitor $visitor_item - impl $generics ::serde::de::Visitor for $visitor_ty $where_clause { + impl $visitor_generics ::serde::de::Visitor for $visitor_ty $where_clause { type Value = $ty; fn visit_map<__V>(&mut self, mut visitor: __V) -> ::std::result::Result<$ty, __V::Error> @@ -534,10 +582,10 @@ fn deserialize_struct_variant( fn deserialize_field_visitor( cx: &ExtCtxt, builder: &aster::AstBuilder, - field_exprs: Vec, + field_attrs: Vec, ) -> Vec> { // Create the field names for the fields. - let field_idents: Vec = (0 .. field_exprs.len()) + let field_idents: Vec = (0 .. field_attrs.len()) .map(|i| builder.id(format!("__field{}", i))) .collect(); @@ -551,57 +599,43 @@ fn deserialize_field_visitor( ) .build(); - let fmts = field_exprs.iter() - .fold(HashSet::new(), |mut set, field_expr| - match field_expr { - &field::FieldLit::Format{ref formats, default: _} => { - for (fmt, _) in formats.iter() { - set.insert(fmt.clone()); - }; - set - }, - _ => set - }); + // A set of all the formats that have specialized field attributes + let formats = field_attrs.iter() + .fold(HashSet::new(), |mut set, field_expr| { + set.extend(field_expr.formats()); + set + }); // Match arms to extract a field from a string let default_field_arms: Vec<_> = field_idents.iter() - .zip(field_exprs.iter()) + .zip(field_attrs.iter()) .map(|(field_ident, field_expr)| { - match field_expr { - &field::FieldLit::Global(ref expr) => - quote_arm!(cx, $expr => { Ok(__Field::$field_ident) }), - &field::FieldLit::Format{formats: _, ref default} => - quote_arm!(cx, $default => { Ok(__Field::$field_ident)}) - } + let expr = field_expr.default_key_expr(); + quote_arm!(cx, $expr => { Ok(__Field::$field_ident) }) }) .collect(); - let body = if fmts.is_empty() { + let body = if formats.is_empty() { + // No formats specific attributes, so no match on format required quote_expr!(cx, match value { $default_field_arms, _ => Err(::serde::de::Error::unknown_field_error(value)), }) } else { - let field_arms : Vec<_> = fmts.iter() + let field_arms : Vec<_> = formats.iter() .map(|fmt| { field_idents.iter() - .zip(field_exprs.iter()) + .zip(field_attrs.iter()) .map(|(field_ident, field_expr)| { - match field_expr { - &field::FieldLit::Global(ref expr) => - quote_arm!(cx, - $expr => { Ok(__Field::$field_ident) }), - &field::FieldLit::Format{ref formats, ref default} => { - let expr = formats.get(fmt).unwrap_or(default); - quote_arm!(cx, - $expr => { Ok(__Field::$field_ident) })}} + let expr = field_expr.key_expr(fmt); + quote_arm!(cx, $expr => { Ok(__Field::$field_ident) }) }) .collect::>() }) .collect(); - let fmt_matches : Vec<_> = fmts.iter() + let fmt_matches : Vec<_> = formats.iter() .zip(field_arms.iter()) .map(|(ref fmt, ref arms)| { quote_arm!(cx, $fmt => { @@ -615,7 +649,7 @@ fn deserialize_field_visitor( .collect(); quote_expr!(cx, - match D::format() { + match __D::format() { $fmt_matches, _ => match value { $default_field_arms, @@ -639,8 +673,8 @@ fn deserialize_field_visitor( phantom: PhantomData } - impl ::serde::de::Visitor for __FieldVisitor - where D: ::serde::de::Deserializer + impl<__D> ::serde::de::Visitor for __FieldVisitor<__D> + where __D: ::serde::de::Deserializer { type Value = __Field; @@ -710,23 +744,30 @@ fn deserialize_map( let extract_values: Vec> = field_names.iter() .zip(struct_def.fields.iter()) - .map(|(field_name, field)| { - let rename = field::field_rename(builder, field); - let name_str = match (rename, field.node.kind) { - (field::Rename::Global(rename), _) - => builder.expr().build_lit(P(rename.clone())), - (field::Rename::None, ast::NamedField(name, _)) - => builder.expr().str(name), - (field::Rename::None, ast::UnnamedField(_)) - => panic!("struct contains unnamed fields"), - (field::Rename::Format(renames), _) - => builder.expr().str("fixme"), - }; - + .zip(field::struct_field_strs(cx, builder, struct_def).iter()) + .map(|((field_name, field), field_attr)| { let missing_expr = if field::default_value(field) { quote_expr!(cx, ::std::default::Default::default()) } else { - quote_expr!(cx, try!(visitor.missing_field($name_str))) + let formats = field_attr.formats(); + let arms : Vec<_> = formats.iter() + .map(|format| { + let key_expr = field_attr.key_expr(format); + quote_arm!(cx, $format => { $key_expr }) + }) + .collect(); + let default = field_attr.default_key_expr(); + if arms.is_empty() { + quote_expr!(cx, try!(visitor.missing_field($default))) + } else { + quote_expr!( + cx, + try!(visitor.missing_field( + match __D::format() { + $arms, + _ => $default + }))) + } }; quote_stmt!(cx, diff --git a/serde_macros/src/field.rs b/serde_macros/src/field.rs index 6c2a9cda..dde2890e 100644 --- a/serde_macros/src/field.rs +++ b/serde_macros/src/field.rs @@ -7,6 +7,8 @@ use syntax::ptr::P; use aster; +use attr::FieldAttrs; + pub enum Rename<'a> { None, Global(&'a ast::Lit), @@ -65,24 +67,16 @@ pub fn field_rename<'a>( .unwrap_or(Rename::None) } -pub enum FieldLit { - Global(P), - Format{ - formats: HashMap, P>, - default: P, - } -} - pub fn struct_field_strs( cx: &ExtCtxt, builder: &aster::AstBuilder, struct_def: &ast::StructDef, -) -> Vec { +) -> Vec { struct_def.fields.iter() .map(|field| { match field_rename(builder, field) { Rename::Global(rename) => - FieldLit::Global( + FieldAttrs::Global( builder.expr().build_lit(P(rename.clone()))), Rename::Format(renames) => { let mut res = HashMap::new(); @@ -90,13 +84,13 @@ pub fn struct_field_strs( renames.into_iter() .map(|(k,v)| (k, builder.expr().build_lit(P(v.clone()))))); - FieldLit::Format{ + FieldAttrs::Format{ formats: res, default: default_field(cx, builder, field.node.kind), } }, Rename::None => { - FieldLit::Global( + FieldAttrs::Global( default_field(cx, builder, field.node.kind)) } } diff --git a/serde_macros/src/lib.rs b/serde_macros/src/lib.rs index 6af23dde..afc37a34 100644 --- a/serde_macros/src/lib.rs +++ b/serde_macros/src/lib.rs @@ -10,9 +10,10 @@ use syntax::ext::base::Decorator; use syntax::parse::token; use rustc::plugin::Registry; -mod ser; -mod de; +mod attr; mod field; +mod de; +mod ser; #[plugin_registrar] #[doc(hidden)] diff --git a/serde_macros/src/ser.rs b/serde_macros/src/ser.rs index 24c35806..c8579313 100644 --- a/serde_macros/src/ser.rs +++ b/serde_macros/src/ser.rs @@ -13,7 +13,7 @@ use syntax::ptr::P; use aster; -use field::{FieldLit, struct_field_strs}; +use field::struct_field_strs; pub fn expand_derive_serialize( cx: &mut ExtCtxt, @@ -523,23 +523,7 @@ fn serialize_struct_visitor( .zip(value_exprs) .enumerate() .map(|(i, (field, value_expr))| { - let key_expr = match field { - FieldLit::Global(x) => x, - FieldLit::Format{formats, default} => { - let arms = formats.iter() - .map(|(fmt, lit)| { - quote_arm!(cx, $fmt => { $lit }) - }) - .collect::>(); - quote_expr!(cx, - { - match S::format() { - $arms, - _ => $default - } - }) - }, - }; + let key_expr = field.serializer_key_expr(cx); quote_arm!(cx, $i => { self.state += 1; diff --git a/src/ser/mod.rs b/src/ser/mod.rs index 612ff059..86ae179f 100644 --- a/src/ser/mod.rs +++ b/src/ser/mod.rs @@ -104,7 +104,7 @@ pub trait Serializer { fn visit_str(&mut self, value: &str) -> Result<(), Self::Error>; /// `visit_bytes` is a hook that enables those serialization formats that support serializing - /// byte slices separately from generic arrays. By default it serializes as a regular array. + /// byte slices separately from generic arrays. By default it serializes as a regular array. #[inline] fn visit_bytes(&mut self, value: &[u8]) -> Result<(), Self::Error> { self.visit_seq(impls::SeqIteratorVisitor::new(value.iter(), Some(value.len()))) diff --git a/tests/test_annotations.rs b/tests/test_annotations.rs index 5f3d37b6..0333d352 100644 --- a/tests/test_annotations.rs +++ b/tests/test_annotations.rs @@ -27,6 +27,16 @@ struct FormatRename { a2: i32, } +#[derive(Debug, PartialEq, Deserialize, Serialize)] +enum SerEnum { + Map { + a: i8, + #[serde(rename(xml= "c", json="d"))] + b: A, + }, +} + + #[test] fn test_default() { let deserialized_value: Default = json::from_str(&"{\"a1\":1,\"a2\":2}").unwrap(); @@ -55,3 +65,15 @@ fn test_format_rename() { let deserialized_value = json::from_str("{\"a1\":1,\"a5\":2}").unwrap(); assert_eq!(value, deserialized_value); } + +#[test] +fn test_enum_format_rename() { + let s1 = String::new(); + let value = SerEnum::Map { a: 0i8, b: s1 }; + let serialized_value = json::to_string(&value).unwrap(); + let ans = "{\"Map\":{\"a\":0,\"d\":\"\"}}"; + assert_eq!(serialized_value, ans); + + let deserialized_value = json::from_str(ans).unwrap(); + assert_eq!(value, deserialized_value); +}