fix: audit sqlx_postgres::types::rust_decimal for overflowing casts

This commit is contained in:
Austin Bonander
2024-08-15 03:29:32 -07:00
parent 16f8b1900d
commit 544fff54e2
2 changed files with 113 additions and 41 deletions

View File

@@ -75,6 +75,14 @@ impl PgNumericSign {
}
impl PgNumeric {
/// Equivalent value of `0::numeric`.
pub const ZERO: Self = PgNumeric::Number {
sign: PgNumericSign::Positive,
digits: vec![],
weight: 0,
scale: 0,
};
pub(crate) fn decode(mut buf: &[u8]) -> Result<Self, BoxDynError> {
// https://github.com/postgres/postgres/blob/bcd1c3630095e48bc3b1eb0fc8e8c8a7c851eba1/src/backend/utils/adt/numeric.c#L874
let num_digits = buf.get_u16();

View File

@@ -1,4 +1,4 @@
use rust_decimal::{prelude::Zero, Decimal};
use rust_decimal::Decimal;
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
@@ -25,9 +25,17 @@ impl TryFrom<PgNumeric> for Decimal {
type Error = BoxDynError;
fn try_from(numeric: PgNumeric) -> Result<Self, BoxDynError> {
let (digits, sign, mut weight, scale) = match numeric {
Decimal::try_from(&numeric)
}
}
impl TryFrom<&'_ PgNumeric> for Decimal {
type Error = BoxDynError;
fn try_from(numeric: &'_ PgNumeric) -> Result<Self, BoxDynError> {
let (digits, sign, mut weight, scale) = match *numeric {
PgNumeric::Number {
digits,
ref digits,
sign,
weight,
scale,
@@ -40,13 +48,13 @@ impl TryFrom<PgNumeric> for Decimal {
if digits.is_empty() {
// Postgres returns an empty digit array for 0
return Ok(0u64.into());
return Ok(Decimal::ZERO);
}
let mut value = Decimal::ZERO;
// Sum over `digits`, multiply each by its weight and add it to `value`.
for digit in digits {
for &digit in digits {
let mul = Decimal::from(10_000i16)
.checked_powi(weight as i64)
.ok_or("value not representable as rust_decimal::Decimal")?;
@@ -71,40 +79,40 @@ impl TryFrom<PgNumeric> for Decimal {
}
}
impl From<Decimal> for PgNumeric {
fn from(value: Decimal) -> Self {
PgNumeric::from(&value)
}
}
// This impl is effectively infallible because `NUMERIC` has a greater range than `Decimal`.
impl From<&'_ Decimal> for PgNumeric {
// Impl has been manually validated.
#[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
fn from(decimal: &Decimal) -> Self {
// `Decimal` added `is_zero()` as an inherent method in a more recent version
if Zero::is_zero(decimal) {
PgNumeric::Number {
sign: PgNumericSign::Positive,
scale: 0,
weight: 0,
digits: vec![],
};
if Decimal::is_zero(decimal) {
return PgNumeric::ZERO;
}
assert!(
(0u32..=28).contains(&decimal.scale()),
"decimal scale out of range {:?}",
decimal.unpack(),
);
// Cannot overflow: always in the range [0, 28]
let scale = decimal.scale() as u16;
// A serialized version of the decimal number. The resulting byte array
// will have the following representation:
//
// Bytes 1-4: flags
// Bytes 5-8: lo portion of m
// Bytes 9-12: mid portion of m
// Bytes 13-16: high portion of m
let mut mantissa = u128::from_le_bytes(decimal.serialize());
let mut mantissa = decimal.mantissa().unsigned_abs();
// chop off the flags
mantissa >>= 32;
// If our scale is not a multiple of 4, we need to go to the next
// multiple.
// If our scale is not a multiple of 4, we need to go to the next multiple.
let groups_diff = scale % 4;
if groups_diff > 0 {
let remainder = 4 - groups_diff as u32;
let power = 10u32.pow(remainder) as u128;
// Impossible to overflow; 0 <= mantissa <= 2^96,
// and we're multiplying by at most 1,000 (giving us a result < 2^106)
mantissa *= power;
}
@@ -113,16 +121,32 @@ impl From<&'_ Decimal> for PgNumeric {
// Convert to base-10000.
while mantissa != 0 {
// Cannot overflow or wrap because of the modulus
digits.push((mantissa % 10_000) as i16);
mantissa /= 10_000;
}
// Change the endianness.
// We started with the low digits first, but they should actually be at the end.
digits.reverse();
// Weight is number of digits on the left side of the decimal.
let digits_after_decimal = (scale + 3) / 4;
let weight = digits.len() as i16 - digits_after_decimal as i16 - 1;
// Cannot overflow: strictly smaller than `scale`.
let digits_after_decimal = scale.div_ceil(4) as i16;
// `mantissa` contains at most 29 decimal digits (log10(2^96)),
// split into at most 8 4-digit segments.
assert!(
digits.len() <= 8,
"digits.len() out of range: {}; unpacked: {:?}",
digits.len(),
decimal.unpack()
);
// Cannot overflow; at most 8
let num_digits = digits.len() as i16;
// Find how many 4-digit segments should go before the decimal point.
// `weight = 0` puts just `digit[0]` before the decimal point, and the rest after.
let weight = num_digits - digits_after_decimal - 1;
// Remove non-significant zeroes.
while let Some(&0) = digits.last() {
@@ -134,6 +158,7 @@ impl From<&'_ Decimal> for PgNumeric {
false => PgNumericSign::Positive,
true => PgNumericSign::Negative,
},
// Cannot overflow; between 0 and 28
scale: scale as i16,
weight,
digits,
@@ -160,7 +185,7 @@ impl Decode<'_, Postgres> for Decimal {
}
#[cfg(test)]
mod decimal_to_pgnumeric {
mod tests {
use super::{Decimal, PgNumeric, PgNumericSign};
use std::convert::TryFrom;
@@ -169,13 +194,13 @@ mod decimal_to_pgnumeric {
let zero: Decimal = "0".parse().unwrap();
assert_eq!(
PgNumeric::try_from(&zero).unwrap(),
PgNumeric::Number {
sign: PgNumericSign::Positive,
scale: 0,
weight: 0,
digits: vec![]
}
PgNumeric::from(&zero),
PgNumeric::ZERO,
);
assert_eq!(
Decimal::try_from(&PgNumeric::ZERO).unwrap(),
Decimal::ZERO
);
}
@@ -343,6 +368,48 @@ mod decimal_to_pgnumeric {
assert_eq!(actual_decimal.scale(), 8);
}
#[test]
fn max_value() {
let expected_numeric = PgNumeric::Number {
sign: PgNumericSign::Positive,
scale: 0,
weight: 7,
digits: vec![7, 9228, 1625, 1426, 4337, 5935, 4395, 0335],
};
assert_eq!(
PgNumeric::try_from(&Decimal::MAX).unwrap(),
expected_numeric
);
let actual_decimal = Decimal::try_from(expected_numeric).unwrap();
assert_eq!(actual_decimal, Decimal::MAX);
// Value split by 10,000's to match the expected digits[]
assert_eq!(actual_decimal.mantissa(), 7_9228_1625_1426_4337_5935_4395_0335);
assert_eq!(actual_decimal.scale(), 0);
}
#[test]
fn max_value_max_scale() {
let mut max_value_max_scale = Decimal::MAX;
max_value_max_scale.set_scale(28).unwrap();
let expected_numeric = PgNumeric::Number {
sign: PgNumericSign::Positive,
scale: 28,
weight: 0,
digits: vec![7, 9228, 1625, 1426, 4337, 5935, 4395, 0335],
};
assert_eq!(
PgNumeric::try_from(&max_value_max_scale).unwrap(),
expected_numeric
);
let actual_decimal = Decimal::try_from(expected_numeric).unwrap();
assert_eq!(actual_decimal, max_value_max_scale);
assert_eq!(actual_decimal.mantissa(), 79_228_162_514_264_337_593_543_950_335);
assert_eq!(actual_decimal.scale(), 28);
}
#[test]
fn issue_423_four_digit() {
// This is a regression test for https://github.com/launchbadge/sqlx/issues/423
@@ -420,7 +487,4 @@ mod decimal_to_pgnumeric {
assert_eq!(actual_decimal.mantissa(), 10000);
assert_eq!(actual_decimal.scale(), 2);
}
#[test]
fn issue_666_trailing_zeroes_at_max_precision() {}
}