From 83e3ec0a45fb61ef16f5160c36f21e989dfd838e Mon Sep 17 00:00:00 2001 From: David Tolnay Date: Thu, 27 Apr 2017 12:47:07 -0700 Subject: [PATCH] Allow integers to be used as map keys again --- src/de.rs | 98 ++++++++++++++++++++++++++++++++++++++++- src/read.rs | 12 +++++ src/value/de.rs | 77 +++++++++++++++++++++++++++++++- tests/test.rs | 114 ++++++++++++++++++++++++------------------------ 4 files changed, 240 insertions(+), 61 deletions(-) diff --git a/src/de.rs b/src/de.rs index b89cd36..5523a00 100644 --- a/src/de.rs +++ b/src/de.rs @@ -852,7 +852,7 @@ impl<'de, 'a, R: Read<'de> + 'a> de::MapAccess<'de> for MapAccess<'a, R> { }; match peek { - Some(b'"') => Ok(Some(try!(seed.deserialize(&mut *self.de)))), + Some(b'"') => seed.deserialize(MapKey { de: &mut *self.de }).map(Some), Some(_) => Err(self.de.peek_error(ErrorCode::KeyMustBeAString)), None => Err(self.de.peek_error(ErrorCode::EofWhileParsingValue)), } @@ -973,6 +973,102 @@ impl<'de, 'a, R: Read<'de> + 'a> de::VariantAccess<'de> for UnitVariantAccess<'a } } +/// Only deserialize from this after peeking a '"' byte! Otherwise it may +/// deserialize invalid JSON successfully. +struct MapKey<'a, R: 'a> { + de: &'a mut Deserializer, +} + +macro_rules! deserialize_integer_key { + ($deserialize:ident => $visit:ident) => { + fn $deserialize(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + self.de.eat_char(); + self.de.str_buf.clear(); + let string = try!(self.de.read.parse_str(&mut self.de.str_buf)); + let integer = try!(string.parse().map_err(de::Error::custom)); + visitor.$visit(integer) + } + } +} + +impl<'de, 'a, R> de::Deserializer<'de> for MapKey<'a, R> +where + R: Read<'de>, +{ + type Error = Error; + + #[inline] + fn deserialize_any(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + self.de.parse_value(visitor) + } + + deserialize_integer_key!(deserialize_i8 => visit_i8); + deserialize_integer_key!(deserialize_i16 => visit_i16); + deserialize_integer_key!(deserialize_i32 => visit_i32); + deserialize_integer_key!(deserialize_i64 => visit_i64); + deserialize_integer_key!(deserialize_u8 => visit_u8); + deserialize_integer_key!(deserialize_u16 => visit_u16); + deserialize_integer_key!(deserialize_u32 => visit_u32); + deserialize_integer_key!(deserialize_u64 => visit_u64); + + #[inline] + fn deserialize_option(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + // Map keys cannot be null. + visitor.visit_some(self) + } + + #[inline] + fn deserialize_newtype_struct(self, _name: &'static str, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + #[inline] + fn deserialize_enum( + self, + name: &'static str, + variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + self.de.deserialize_enum(name, variants, visitor) + } + + #[inline] + fn deserialize_bytes(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + self.de.deserialize_bytes(visitor) + } + + #[inline] + fn deserialize_byte_buf(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + self.de.deserialize_bytes(visitor) + } + + forward_to_deserialize_any! { + bool f32 f64 char str string unit unit_struct seq tuple tuple_struct map + struct identifier ignored_any + } +} + ////////////////////////////////////////////////////////////////////////////// /// Iterator that deserializes a stream into multiple JSON values. diff --git a/src/read.rs b/src/read.rs index d3e47a6..a3aa6bf 100644 --- a/src/read.rs +++ b/src/read.rs @@ -7,6 +7,7 @@ // except according to those terms. use std::{char, cmp, io, str}; +use std::ops::Deref; use iter::LineColIterator; @@ -82,6 +83,17 @@ pub enum Reference<'b, 'c, T: ?Sized + 'static> { Copied(&'c T), } +impl<'b, 'c, T: ?Sized + 'static> Deref for Reference<'b, 'c, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + match *self { + Reference::Borrowed(b) => b, + Reference::Copied(c) => c, + } + } +} + /// JSON input source that reads from a std::io input stream. pub struct IoRead where diff --git a/src/value/de.rs b/src/value/de.rs index 189b22e..776362d 100644 --- a/src/value/de.rs +++ b/src/value/de.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::fmt; use std::i64; use std::io; @@ -461,7 +462,8 @@ impl<'de> MapAccess<'de> for MapDeserializer { match self.iter.next() { Some((key, value)) => { self.value = Some(value); - seed.deserialize(key.into_deserializer()).map(Some) + let key_de = MapKeyDeserializer { key: Cow::Owned(key) }; + seed.deserialize(key_de).map(Some) } None => Ok(None), } @@ -770,7 +772,8 @@ impl<'de> MapAccess<'de> for MapRefDeserializer<'de> { match self.iter.next() { Some((key, value)) => { self.value = Some(value); - seed.deserialize((&**key).into_deserializer()).map(Some) + let key_de = MapKeyDeserializer { key: Cow::Borrowed(&**key) }; + seed.deserialize(key_de).map(Some) } None => Ok(None), } @@ -812,6 +815,76 @@ impl<'de> serde::Deserializer<'de> for MapRefDeserializer<'de> { } } +struct MapKeyDeserializer<'de> { + key: Cow<'de, str>, +} + +macro_rules! deserialize_integer_key { + ($deserialize:ident => $visit:ident) => { + fn $deserialize(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let integer = try!(self.key.parse().map_err(serde::de::Error::custom)); + visitor.$visit(integer) + } + } +} + +impl<'de> serde::Deserializer<'de> for MapKeyDeserializer<'de> { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.key.into_deserializer().deserialize_any(visitor) + } + + deserialize_integer_key!(deserialize_i8 => visit_i8); + deserialize_integer_key!(deserialize_i16 => visit_i16); + deserialize_integer_key!(deserialize_i32 => visit_i32); + deserialize_integer_key!(deserialize_i64 => visit_i64); + deserialize_integer_key!(deserialize_u8 => visit_u8); + deserialize_integer_key!(deserialize_u16 => visit_u16); + deserialize_integer_key!(deserialize_u32 => visit_u32); + deserialize_integer_key!(deserialize_u64 => visit_u64); + + #[inline] + fn deserialize_option(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + // Map keys cannot be null. + visitor.visit_some(self) + } + + #[inline] + fn deserialize_newtype_struct(self, _name: &'static str, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_enum( + self, + name: &'static str, + variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.key.into_deserializer().deserialize_enum(name, variants, visitor) + } + + forward_to_deserialize_any! { + bool f32 f64 char str string bytes byte_buf unit unit_struct seq tuple + tuple_struct map struct identifier ignored_any + } +} + impl Value { fn unexpected(&self) -> Unexpected { match *self { diff --git a/tests/test.rs b/tests/test.rs index 97c2b80..a4bb342 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -568,33 +568,6 @@ fn test_write_newtype_struct() { test_encode_ok(&[(outer, r#"{"outer":{"inner":123}}"#)]); } -#[test] -fn test_write_map_with_integer_keys_issue_221() { - let mut map = BTreeMap::new(); - map.insert(0, "x"); // map with integer key - - assert_eq!( - serde_json::to_value(&map).unwrap(), - json!({"0": "x"}) - ); - - test_encode_ok(&[(&map, r#"{"0":"x"}"#)]); - - #[derive(Eq, PartialEq, Ord, PartialOrd)] - struct Float; - impl Serialize for Float { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - serializer.serialize_f32(1.0) - } - } - let mut map = BTreeMap::new(); - map.insert(Float, "x"); // map with float key - assert!(serde_json::to_value(&map).is_err()); -} - fn test_parse_ok(tests: Vec<(&str, T)>) where T: Clone + Debug + PartialEq + ser::Serialize + de::DeserializeOwned, @@ -1504,32 +1477,6 @@ fn test_serialize_rejects_adt_keys() { assert_eq!(err.to_string(), "key must be a string"); } -#[test] -fn test_effectively_string_keys() { - #[derive(Eq, PartialEq, Ord, PartialOrd, Debug, Serialize, Deserialize)] - enum Enum { - Zero, - One, - } - let map = treemap! { - Enum::Zero => 0, - Enum::One => 1 - }; - let expected = r#"{"Zero":0,"One":1}"#; - assert_eq!(to_string(&map).unwrap(), expected); - assert_eq!(map, from_str(expected).unwrap()); - - #[derive(Eq, PartialEq, Ord, PartialOrd, Debug, Serialize, Deserialize)] - struct Wrapper(String); - let map = treemap! { - Wrapper("zero".to_owned()) => 0, - Wrapper("one".to_owned()) => 1 - }; - let expected = r#"{"one":1,"zero":0}"#; - assert_eq!(to_string(&map).unwrap(), expected); - assert_eq!(map, from_str(expected).unwrap()); -} - #[test] fn test_bytes_ser() { let buf = vec![]; @@ -1675,15 +1622,66 @@ fn test_stack_overflow() { } #[test] -fn test_allow_ser_integers_as_map_keys() { +fn test_integer_key() { + // map with integer keys let map = treemap!( 1 => 2, - 2 => 4, - -1 => 6, - -2 => 8 + -1 => 6 ); + let j = r#"{"-1":6,"1":2}"#; + test_encode_ok(&[(&map, j)]); + test_parse_ok(vec![(j, map)]); - assert_eq!(to_string(&map).unwrap(), r#"{"-2":8,"-1":6,"1":2,"2":4}"#); + let j = r#"{"x":null}"#; + test_parse_err::>( + &[ + (j, "invalid digit found in string at line 1 column 4"), + ], + ); +} + +#[test] +fn test_deny_float_key() { + #[derive(Eq, PartialEq, Ord, PartialOrd)] + struct Float; + impl Serialize for Float { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_f32(1.0) + } + } + + // map with float key + let map = treemap!(Float => "x"); + assert!(serde_json::to_value(&map).is_err()); +} + +#[test] +fn test_effectively_string_keys() { + #[derive(Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Serialize, Deserialize)] + enum Enum { + One, + Two, + } + let map = treemap! { + Enum::One => 1, + Enum::Two => 2 + }; + let expected = r#"{"One":1,"Two":2}"#; + test_encode_ok(&[(&map, expected)]); + test_parse_ok(vec![(expected, map)]); + + #[derive(Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Serialize, Deserialize)] + struct Wrapper(String); + let map = treemap! { + Wrapper("zero".to_owned()) => 0, + Wrapper("one".to_owned()) => 1 + }; + let expected = r#"{"one":1,"zero":0}"#; + test_encode_ok(&[(&map, expected)]); + test_parse_ok(vec![(expected, map)]); } #[test]