refactor PgNumeric so NaN can't be misinterpreted, document types

This commit is contained in:
Austin Bonander 2020-03-18 19:04:11 -07:00
parent 63f5592ecf
commit 3a43e939e3
5 changed files with 161 additions and 83 deletions

View File

@ -36,6 +36,8 @@ impl TryFrom<&'_ BigDecimal> for PgNumeric {
let weight_10 = base_10.len() as i64 - exp;
// scale is only nonzero when we have fractional digits
// since `exp` is the _negative_ decimal exponent, it tells us
// exactly what our scale should be
let scale: i16 = cmp::max(0, exp).try_into()?;
// there's an implicit +1 offset in the interpretation
@ -77,7 +79,7 @@ impl TryFrom<&'_ BigDecimal> for PgNumeric {
digits.pop();
}
Ok(PgNumeric {
Ok(PgNumeric::Number {
sign: match sign {
Sign::Plus | Sign::NoSign => PgNumericSign::Positive,
Sign::Minus => PgNumericSign::Negative,
@ -93,24 +95,33 @@ impl TryFrom<PgNumeric> for BigDecimal {
type Error = crate::Error;
fn try_from(numeric: PgNumeric) -> crate::Result<Self> {
let sign = match numeric.sign {
_ if numeric.digits.is_empty() => Sign::NoSign,
PgNumericSign::Positive => Sign::Plus,
PgNumericSign::Negative => Sign::Minus,
PgNumericSign::NotANumber => {
let (digits, sign, weight) = match numeric {
PgNumeric::Number {
digits,
sign,
weight,
..
} => (digits, sign, weight),
PgNumeric::NotANumber => {
return Err(crate::Error::Decode(
"BigDecimal does not support NaN values".into(),
))
}
};
let sign = match sign {
_ if digits.is_empty() => Sign::NoSign,
PgNumericSign::Positive => Sign::Plus,
PgNumericSign::Negative => Sign::Minus,
};
// `scale` is effectively the number of places left to shift the decimal point
// weight is 0 if the decimal point falls after the first base-10000 digit
let scale = (numeric.digits.len() as i64 - numeric.weight as i64 - 1) * 4;
let scale = (digits.len() as i64 - weight as i64 - 1) * 4;
// no optimized algorithm for base-10 so use base-100 for faster processing
let mut cents = Vec::with_capacity(numeric.digits.len() * 2);
for digit in &numeric.digits {
let mut cents = Vec::with_capacity(digits.len() * 2);
for digit in &digits {
cents.push((digit / 100) as u8);
cents.push((digit % 100) as u8);
}
@ -118,8 +129,6 @@ impl TryFrom<PgNumeric> for BigDecimal {
let bigint = BigInt::from_radix_be(sign, &cents, 100)
.expect("BUG digit outside of given radix, check math above");
dbg!(&numeric);
Ok(BigDecimal::new(bigint, scale))
}
}
@ -156,7 +165,7 @@ fn test_bigdecimal_to_pgnumeric() {
let one: BigDecimal = "1".parse().unwrap();
assert_eq!(
PgNumeric::try_from(&one).unwrap(),
PgNumeric {
PgNumeric::Number {
sign: PgNumericSign::Positive,
scale: 0,
weight: 0,
@ -167,7 +176,7 @@ fn test_bigdecimal_to_pgnumeric() {
let ten: BigDecimal = "10".parse().unwrap();
assert_eq!(
PgNumeric::try_from(&ten).unwrap(),
PgNumeric {
PgNumeric::Number {
sign: PgNumericSign::Positive,
scale: 0,
weight: 0,
@ -178,7 +187,7 @@ fn test_bigdecimal_to_pgnumeric() {
let one_hundred: BigDecimal = "100".parse().unwrap();
assert_eq!(
PgNumeric::try_from(&one_hundred).unwrap(),
PgNumeric {
PgNumeric::Number {
sign: PgNumericSign::Positive,
scale: 0,
weight: 0,
@ -190,7 +199,7 @@ fn test_bigdecimal_to_pgnumeric() {
let ten_thousand: BigDecimal = "10000".parse().unwrap();
assert_eq!(
PgNumeric::try_from(&ten_thousand).unwrap(),
PgNumeric {
PgNumeric::Number {
sign: PgNumericSign::Positive,
scale: 0,
weight: 1,
@ -201,7 +210,7 @@ fn test_bigdecimal_to_pgnumeric() {
let two_digits: BigDecimal = "12345".parse().unwrap();
assert_eq!(
PgNumeric::try_from(&two_digits).unwrap(),
PgNumeric {
PgNumeric::Number {
sign: PgNumericSign::Positive,
scale: 0,
weight: 1,
@ -212,7 +221,7 @@ fn test_bigdecimal_to_pgnumeric() {
let one_tenth: BigDecimal = "0.1".parse().unwrap();
assert_eq!(
PgNumeric::try_from(&one_tenth).unwrap(),
PgNumeric {
PgNumeric::Number {
sign: PgNumericSign::Positive,
scale: 1,
weight: -1,
@ -223,7 +232,7 @@ fn test_bigdecimal_to_pgnumeric() {
let decimal: BigDecimal = "1.2345".parse().unwrap();
assert_eq!(
PgNumeric::try_from(&decimal).unwrap(),
PgNumeric {
PgNumeric::Number {
sign: PgNumericSign::Positive,
scale: 4,
weight: 0,
@ -234,7 +243,7 @@ fn test_bigdecimal_to_pgnumeric() {
let decimal: BigDecimal = "0.12345".parse().unwrap();
assert_eq!(
PgNumeric::try_from(&decimal).unwrap(),
PgNumeric {
PgNumeric::Number {
sign: PgNumericSign::Positive,
scale: 5,
weight: -1,
@ -245,7 +254,7 @@ fn test_bigdecimal_to_pgnumeric() {
let decimal: BigDecimal = "0.01234".parse().unwrap();
assert_eq!(
PgNumeric::try_from(&decimal).unwrap(),
PgNumeric {
PgNumeric::Number {
sign: PgNumericSign::Positive,
scale: 5,
weight: -1,
@ -256,7 +265,7 @@ fn test_bigdecimal_to_pgnumeric() {
let decimal: BigDecimal = "12345.67890".parse().unwrap();
assert_eq!(
PgNumeric::try_from(&decimal).unwrap(),
PgNumeric {
PgNumeric::Number {
sign: PgNumericSign::Positive,
scale: 5,
weight: 1,
@ -267,7 +276,7 @@ fn test_bigdecimal_to_pgnumeric() {
let one_digit_decimal: BigDecimal = "0.00001234".parse().unwrap();
assert_eq!(
PgNumeric::try_from(&one_digit_decimal).unwrap(),
PgNumeric {
PgNumeric::Number {
sign: PgNumericSign::Positive,
scale: 8,
weight: -2,

View File

@ -10,32 +10,68 @@ use crate::postgres::{PgTypeInfo, PgValue, Postgres};
use crate::types::Type;
use crate::Error;
/// Wire representation of a Postgres NUMERIC type
/// Represents a `NUMERIC` value in the **Postgres** wire protocol.
#[derive(Debug, PartialEq, Eq)]
pub struct PgNumeric {
pub sign: PgNumericSign,
pub scale: i16,
pub weight: i16,
pub digits: Vec<i16>,
pub enum PgNumeric {
/// Equivalent to the `'NaN'` value in Postgres. The result of, e.g. `1 / 0`.
NotANumber,
/// A populated `NUMERIC` value.
///
/// A description of these fields can be found here (although the type being described is the
/// version for in-memory calculations, the field names are the same):
/// https://github.com/postgres/postgres/blob/bcd1c3630095e48bc3b1eb0fc8e8c8a7c851eba1/src/backend/utils/adt/numeric.c#L224-L269
Number {
/// The sign of the value: positive (also set for 0 and -0), or negative.
sign: PgNumericSign,
/// The digits of the number in base-10000 with the most significant digit first
/// (big-endian).
///
/// The length of this vector must not overflow `i16` for the binary protocol.
///
/// *Note*: the `Encode` implementation will panic if any digit is `>= 10000`.
digits: Vec<i16>,
/// The scaling factor of the number, such that the value will be interpreted as
///
/// ```text
/// digits[0] * 10,000 ^ weight
/// + digits[1] * 10,000 ^ (weight - 1)
/// ...
/// + digits[N] * 10,000 ^ (weight - N) where N = digits.len() - 1
/// ```
/// May be negative.
weight: i16,
/// How many _decimal_ (base-10) digits following the decimal point to consider in
/// arithmetic regardless of how many actually follow the decimal point as determined by
/// `weight`--the comment in the Postgres code linked above recommends using this only for
/// ignoring unnecessary trailing zeroes (as trimming nonzero digits means reducing the
/// precision of the value).
///
/// Must be `>= 0`.
scale: i16,
},
}
// https://github.com/postgres/postgres/blob/bcd1c3630095e48bc3b1eb0fc8e8c8a7c851eba1/src/backend/utils/adt/numeric.c#L167-L170
const SIGN_POS: u16 = 0x0000;
const SIGN_NEG: u16 = 0x4000;
const SIGN_NAN: u16 = 0xC000; // overflows i16 (C equivalent truncates from integer literal)
/// Possible sign values for [PgNumeric].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[repr(i16)]
#[repr(u16)]
pub enum PgNumericSign {
Positive = 0x0000,
Negative = 0x4000,
NotANumber = -0x4000, // 0xC000
Positive = SIGN_POS,
Negative = SIGN_NEG,
}
impl PgNumericSign {
fn from_u16(sign: i16) -> crate::Result<PgNumericSign> {
// https://github.com/postgres/postgres/blob/bcd1c3630095e48bc3b1eb0fc8e8c8a7c851eba1/src/backend/utils/adt/numeric.c#L167-L170
match sign {
0x0000 => Ok(PgNumericSign::Positive),
0x4000 => Ok(PgNumericSign::Negative),
-0x4000 => Ok(PgNumericSign::NotANumber),
fn try_from_u16(val: u16) -> crate::Result<Self> {
match val {
SIGN_POS => Ok(PgNumericSign::Positive),
SIGN_NEG => Ok(PgNumericSign::Negative),
SIGN_NAN => panic!("BUG: sign value for NaN passed to PgNumericSign"),
_ => Err(Error::Decode(
format!("unknown value for PgNumericSign: {:#04X}", sign).into(),
format!("invalid value for PgNumericSign: {:#04X}", val).into(),
)),
}
}
@ -46,30 +82,32 @@ impl Type<Postgres> for PgNumeric {
PgTypeInfo::new(TypeId::NUMERIC, "NUMERIC")
}
}
impl PgNumeric {
pub(crate) fn from_bytes(mut bytes: &[u8]) -> crate::Result<Self> {
// https://github.com/postgres/postgres/blob/bcd1c3630095e48bc3b1eb0fc8e8c8a7c851eba1/src/backend/utils/adt/numeric.c#L874
let num_digits = bytes.get_u16::<BigEndian>()?;
let weight = bytes.get_i16::<BigEndian>()?;
let sign = bytes.get_i16::<BigEndian>()?;
let sign = bytes.get_u16::<BigEndian>()?;
let scale = bytes.get_i16::<BigEndian>()?;
let digits: Vec<_> = (0..num_digits)
.map(|_| bytes.get_i16::<BigEndian>())
.collect::<Result<_, _>>()?;
Ok(PgNumeric {
sign: PgNumericSign::from_u16(sign)?,
scale,
weight,
digits,
})
if sign == SIGN_NAN {
Ok(PgNumeric::NotANumber)
} else {
let digits: Vec<_> = (0..num_digits)
.map(|_| bytes.get_i16::<BigEndian>())
.collect::<Result<_, _>>()?;
Ok(PgNumeric::Number {
sign: PgNumericSign::try_from_u16(sign)?,
scale,
weight,
digits,
})
}
}
}
/// ### Note
/// Receiving `PgNumeric` is only supported for the Postgres binary (prepared statements) protocol.
///
/// Receiving `PgNumeric` is currently only supported for the Postgres
/// binary (prepared statements) protocol.
impl Decode<'_, Postgres> for PgNumeric {
fn decode(value: Option<PgValue>) -> crate::Result<Self> {
if let PgValue::Binary(bytes) = value.try_into()? {
@ -81,32 +119,47 @@ impl Decode<'_, Postgres> for PgNumeric {
}
}
}
/// ### Panics
///
/// * If `self.digits.len()` overflows `i16`
/// * If any element in `self.digits` is greater than or equal to 10000
/// * If `digits.len()` overflows `i16`
/// * If any element in `digits` is greater than or equal to 10000
impl Encode<Postgres> for PgNumeric {
fn encode(&self, buf: &mut Vec<u8>) {
let digits_len: i16 = self
.digits
.len()
.try_into()
.expect("PgNumeric.digits.len() should not overflow i16");
match *self {
PgNumeric::Number {
ref digits,
sign,
scale,
weight,
} => {
let digits_len: i16 = digits
.len()
.try_into()
.expect("PgNumeric.digits.len() should not overflow i16");
buf.put_i16::<BigEndian>(digits_len);
buf.put_i16::<BigEndian>(self.weight);
buf.put_i16::<BigEndian>(self.sign as i16);
buf.put_i16::<BigEndian>(self.scale);
for &digit in &self.digits {
assert!(digit < 10000, "PgNumeric digits must be in base-10000");
buf.put_i16::<BigEndian>(digit);
buf.put_i16::<BigEndian>(digits_len);
buf.put_i16::<BigEndian>(weight);
buf.put_i16::<BigEndian>(sign as i16);
buf.put_i16::<BigEndian>(scale);
for &digit in digits {
assert!(digit < 10000, "PgNumeric digits must be in base-10000");
buf.put_i16::<BigEndian>(digit);
}
}
PgNumeric::NotANumber => {
buf.put_i16::<BigEndian>(0);
buf.put_i16::<BigEndian>(0);
buf.put_u16::<BigEndian>(SIGN_NAN);
buf.put_i16::<BigEndian>(0);
}
}
}
fn size_hint(&self) -> usize {
// 4 i16's plus digits
8 + self.digits.len() * 2
8 + if let PgNumeric::Number { digits, .. } = self {
digits.len() * 2
} else {
0
}
}
}

View File

@ -2,6 +2,7 @@
name = "sqlx-test"
version = "0.1.0"
edition = "2018"
publish = false
[dependencies]
sqlx = { default-features = false, path = ".." }

View File

@ -19,7 +19,7 @@ where
// Test type encoding and decoding
#[macro_export]
macro_rules! test_type {
($name:ident($db:ident, $ty:ty, $($text:literal == $value:expr),+)) => {
($name:ident($db:ident, $ty:ty, $($text:literal == $value:expr),+ $(,)?)) => {
$crate::test_prepared_type!($name($db, $ty, $($text == $value),+));
$crate::test_unprepared_type!($name($db, $ty, $($text == $value),+));
}
@ -28,7 +28,7 @@ macro_rules! test_type {
// Test type decoding for the simple (unprepared) query API
#[macro_export]
macro_rules! test_unprepared_type {
($name:ident($db:ident, $ty:ty, $($text:literal == $value:expr),+)) => {
($name:ident($db:ident, $ty:ty, $($text:literal == $value:expr),+ $(,)?)) => {
paste::item! {
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
@ -55,7 +55,7 @@ macro_rules! test_unprepared_type {
// Test type encoding and decoding for the prepared query API
#[macro_export]
macro_rules! test_prepared_type {
($name:ident($db:ident, $ty:ty, $($text:literal == $value:expr),+)) => {
($name:ident($db:ident, $ty:ty, $($text:literal == $value:expr),+ $(,)?)) => {
paste::item! {
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]

View File

@ -53,48 +53,63 @@ test_type!(bytea(
test_prepared_type!(numeric(
Postgres,
PgNumeric,
"0::numeric"
== PgNumeric::Number {
sign: PgNumericSign::Positive,
weight: 0,
scale: 0,
digits: vec![]
},
"(-0)::numeric"
== PgNumeric::Number {
sign: PgNumericSign::Positive,
weight: 0,
scale: 0,
digits: vec![]
},
"1::numeric"
== PgNumeric {
== PgNumeric::Number {
sign: PgNumericSign::Positive,
weight: 0,
scale: 0,
digits: vec![1]
},
"1234::numeric"
== PgNumeric {
== PgNumeric::Number {
sign: PgNumericSign::Positive,
weight: 0,
scale: 0,
digits: vec![1234]
},
"10000::numeric"
== PgNumeric {
== PgNumeric::Number {
sign: PgNumericSign::Positive,
weight: 1,
scale: 0,
digits: vec![1]
},
"0.1::numeric"
== PgNumeric {
== PgNumeric::Number {
sign: PgNumericSign::Positive,
weight: -1,
scale: 1,
digits: vec![1000]
},
"0.01234::numeric"
== PgNumeric {
== PgNumeric::Number {
sign: PgNumericSign::Positive,
weight: -1,
scale: 5,
digits: vec![123, 4000]
},
"12.34::numeric"
== PgNumeric {
== PgNumeric::Number {
sign: PgNumericSign::Positive,
weight: 0,
scale: 2,
digits: vec![12, 3400]
}
},
"'NaN'::numeric" == PgNumeric::NotANumber,
));
#[cfg(feature = "bigdecimal")]
@ -106,7 +121,7 @@ test_type!(decimal(
"0.1::numeric" == "0.1".parse::<sqlx::types::BigDecimal>().unwrap(),
"0.01234::numeric" == "0.01234".parse::<sqlx::types::BigDecimal>().unwrap(),
"12.34::numeric" == "12.34".parse::<sqlx::types::BigDecimal>().unwrap(),
"12345.6789::numeric" == "12345.6789".parse::<sqlx::types::BigDecimal>().unwrap()
"12345.6789::numeric" == "12345.6789".parse::<sqlx::types::BigDecimal>().unwrap(),
));
#[cfg(feature = "uuid")]