diff --git a/sqlx-postgres/src/types/numeric.rs b/sqlx-postgres/src/types/numeric.rs index b281de46f..641687291 100644 --- a/sqlx-postgres/src/types/numeric.rs +++ b/sqlx-postgres/src/types/numeric.rs @@ -75,6 +75,14 @@ impl PgNumericSign { } impl PgNumeric { + /// Equivalent value of `0::numeric`. + pub const ZERO: Self = PgNumeric::Number { + sign: PgNumericSign::Positive, + digits: vec![], + weight: 0, + scale: 0, + }; + pub(crate) fn decode(mut buf: &[u8]) -> Result { // https://github.com/postgres/postgres/blob/bcd1c3630095e48bc3b1eb0fc8e8c8a7c851eba1/src/backend/utils/adt/numeric.c#L874 let num_digits = buf.get_u16(); diff --git a/sqlx-postgres/src/types/rust_decimal.rs b/sqlx-postgres/src/types/rust_decimal.rs index fa66eb393..d94dfe34c 100644 --- a/sqlx-postgres/src/types/rust_decimal.rs +++ b/sqlx-postgres/src/types/rust_decimal.rs @@ -1,4 +1,4 @@ -use rust_decimal::{prelude::Zero, Decimal}; +use rust_decimal::Decimal; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; @@ -25,9 +25,17 @@ impl TryFrom for Decimal { type Error = BoxDynError; fn try_from(numeric: PgNumeric) -> Result { - let (digits, sign, mut weight, scale) = match numeric { + Decimal::try_from(&numeric) + } +} + +impl TryFrom<&'_ PgNumeric> for Decimal { + type Error = BoxDynError; + + fn try_from(numeric: &'_ PgNumeric) -> Result { + let (digits, sign, mut weight, scale) = match *numeric { PgNumeric::Number { - digits, + ref digits, sign, weight, scale, @@ -40,13 +48,13 @@ impl TryFrom for Decimal { if digits.is_empty() { // Postgres returns an empty digit array for 0 - return Ok(0u64.into()); + return Ok(Decimal::ZERO); } let mut value = Decimal::ZERO; // Sum over `digits`, multiply each by its weight and add it to `value`. - for digit in digits { + for &digit in digits { let mul = Decimal::from(10_000i16) .checked_powi(weight as i64) .ok_or("value not representable as rust_decimal::Decimal")?; @@ -71,40 +79,40 @@ impl TryFrom for Decimal { } } +impl From for PgNumeric { + fn from(value: Decimal) -> Self { + PgNumeric::from(&value) + } +} + // This impl is effectively infallible because `NUMERIC` has a greater range than `Decimal`. impl From<&'_ Decimal> for PgNumeric { + // Impl has been manually validated. + #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)] fn from(decimal: &Decimal) -> Self { - // `Decimal` added `is_zero()` as an inherent method in a more recent version - if Zero::is_zero(decimal) { - PgNumeric::Number { - sign: PgNumericSign::Positive, - scale: 0, - weight: 0, - digits: vec![], - }; + if Decimal::is_zero(decimal) { + return PgNumeric::ZERO; } + assert!( + (0u32..=28).contains(&decimal.scale()), + "decimal scale out of range {:?}", + decimal.unpack(), + ); + + // Cannot overflow: always in the range [0, 28] let scale = decimal.scale() as u16; - // A serialized version of the decimal number. The resulting byte array - // will have the following representation: - // - // Bytes 1-4: flags - // Bytes 5-8: lo portion of m - // Bytes 9-12: mid portion of m - // Bytes 13-16: high portion of m - let mut mantissa = u128::from_le_bytes(decimal.serialize()); + let mut mantissa = decimal.mantissa().unsigned_abs(); - // chop off the flags - mantissa >>= 32; - - // If our scale is not a multiple of 4, we need to go to the next - // multiple. + // If our scale is not a multiple of 4, we need to go to the next multiple. let groups_diff = scale % 4; if groups_diff > 0 { let remainder = 4 - groups_diff as u32; let power = 10u32.pow(remainder) as u128; + // Impossible to overflow; 0 <= mantissa <= 2^96, + // and we're multiplying by at most 1,000 (giving us a result < 2^106) mantissa *= power; } @@ -113,16 +121,32 @@ impl From<&'_ Decimal> for PgNumeric { // Convert to base-10000. while mantissa != 0 { + // Cannot overflow or wrap because of the modulus digits.push((mantissa % 10_000) as i16); mantissa /= 10_000; } - // Change the endianness. + // We started with the low digits first, but they should actually be at the end. digits.reverse(); - // Weight is number of digits on the left side of the decimal. - let digits_after_decimal = (scale + 3) / 4; - let weight = digits.len() as i16 - digits_after_decimal as i16 - 1; + // Cannot overflow: strictly smaller than `scale`. + let digits_after_decimal = scale.div_ceil(4) as i16; + + // `mantissa` contains at most 29 decimal digits (log10(2^96)), + // split into at most 8 4-digit segments. + assert!( + digits.len() <= 8, + "digits.len() out of range: {}; unpacked: {:?}", + digits.len(), + decimal.unpack() + ); + + // Cannot overflow; at most 8 + let num_digits = digits.len() as i16; + + // Find how many 4-digit segments should go before the decimal point. + // `weight = 0` puts just `digit[0]` before the decimal point, and the rest after. + let weight = num_digits - digits_after_decimal - 1; // Remove non-significant zeroes. while let Some(&0) = digits.last() { @@ -134,6 +158,7 @@ impl From<&'_ Decimal> for PgNumeric { false => PgNumericSign::Positive, true => PgNumericSign::Negative, }, + // Cannot overflow; between 0 and 28 scale: scale as i16, weight, digits, @@ -160,7 +185,7 @@ impl Decode<'_, Postgres> for Decimal { } #[cfg(test)] -mod decimal_to_pgnumeric { +mod tests { use super::{Decimal, PgNumeric, PgNumericSign}; use std::convert::TryFrom; @@ -169,13 +194,13 @@ mod decimal_to_pgnumeric { let zero: Decimal = "0".parse().unwrap(); assert_eq!( - PgNumeric::try_from(&zero).unwrap(), - PgNumeric::Number { - sign: PgNumericSign::Positive, - scale: 0, - weight: 0, - digits: vec![] - } + PgNumeric::from(&zero), + PgNumeric::ZERO, + ); + + assert_eq!( + Decimal::try_from(&PgNumeric::ZERO).unwrap(), + Decimal::ZERO ); } @@ -343,6 +368,48 @@ mod decimal_to_pgnumeric { assert_eq!(actual_decimal.scale(), 8); } + #[test] + fn max_value() { + let expected_numeric = PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 7, + digits: vec![7, 9228, 1625, 1426, 4337, 5935, 4395, 0335], + }; + assert_eq!( + PgNumeric::try_from(&Decimal::MAX).unwrap(), + expected_numeric + ); + + let actual_decimal = Decimal::try_from(expected_numeric).unwrap(); + assert_eq!(actual_decimal, Decimal::MAX); + // Value split by 10,000's to match the expected digits[] + assert_eq!(actual_decimal.mantissa(), 7_9228_1625_1426_4337_5935_4395_0335); + assert_eq!(actual_decimal.scale(), 0); + } + + #[test] + fn max_value_max_scale() { + let mut max_value_max_scale = Decimal::MAX; + max_value_max_scale.set_scale(28).unwrap(); + + let expected_numeric = PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 28, + weight: 0, + digits: vec![7, 9228, 1625, 1426, 4337, 5935, 4395, 0335], + }; + assert_eq!( + PgNumeric::try_from(&max_value_max_scale).unwrap(), + expected_numeric + ); + + let actual_decimal = Decimal::try_from(expected_numeric).unwrap(); + assert_eq!(actual_decimal, max_value_max_scale); + assert_eq!(actual_decimal.mantissa(), 79_228_162_514_264_337_593_543_950_335); + assert_eq!(actual_decimal.scale(), 28); + } + #[test] fn issue_423_four_digit() { // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 @@ -420,7 +487,4 @@ mod decimal_to_pgnumeric { assert_eq!(actual_decimal.mantissa(), 10000); assert_eq!(actual_decimal.scale(), 2); } - - #[test] - fn issue_666_trailing_zeroes_at_max_precision() {} }