mirror of
https://github.com/launchbadge/sqlx.git
synced 2026-03-22 18:14:11 +00:00
fix: audit sqlx_postgres::types::rust_decimal for overflowing casts
This commit is contained in:
@@ -75,6 +75,14 @@ impl PgNumericSign {
|
||||
}
|
||||
|
||||
impl PgNumeric {
|
||||
/// Equivalent value of `0::numeric`.
|
||||
pub const ZERO: Self = PgNumeric::Number {
|
||||
sign: PgNumericSign::Positive,
|
||||
digits: vec![],
|
||||
weight: 0,
|
||||
scale: 0,
|
||||
};
|
||||
|
||||
pub(crate) fn decode(mut buf: &[u8]) -> Result<Self, BoxDynError> {
|
||||
// https://github.com/postgres/postgres/blob/bcd1c3630095e48bc3b1eb0fc8e8c8a7c851eba1/src/backend/utils/adt/numeric.c#L874
|
||||
let num_digits = buf.get_u16();
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use rust_decimal::{prelude::Zero, Decimal};
|
||||
use rust_decimal::Decimal;
|
||||
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::{Encode, IsNull};
|
||||
@@ -25,9 +25,17 @@ impl TryFrom<PgNumeric> for Decimal {
|
||||
type Error = BoxDynError;
|
||||
|
||||
fn try_from(numeric: PgNumeric) -> Result<Self, BoxDynError> {
|
||||
let (digits, sign, mut weight, scale) = match numeric {
|
||||
Decimal::try_from(&numeric)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&'_ PgNumeric> for Decimal {
|
||||
type Error = BoxDynError;
|
||||
|
||||
fn try_from(numeric: &'_ PgNumeric) -> Result<Self, BoxDynError> {
|
||||
let (digits, sign, mut weight, scale) = match *numeric {
|
||||
PgNumeric::Number {
|
||||
digits,
|
||||
ref digits,
|
||||
sign,
|
||||
weight,
|
||||
scale,
|
||||
@@ -40,13 +48,13 @@ impl TryFrom<PgNumeric> for Decimal {
|
||||
|
||||
if digits.is_empty() {
|
||||
// Postgres returns an empty digit array for 0
|
||||
return Ok(0u64.into());
|
||||
return Ok(Decimal::ZERO);
|
||||
}
|
||||
|
||||
let mut value = Decimal::ZERO;
|
||||
|
||||
// Sum over `digits`, multiply each by its weight and add it to `value`.
|
||||
for digit in digits {
|
||||
for &digit in digits {
|
||||
let mul = Decimal::from(10_000i16)
|
||||
.checked_powi(weight as i64)
|
||||
.ok_or("value not representable as rust_decimal::Decimal")?;
|
||||
@@ -71,40 +79,40 @@ impl TryFrom<PgNumeric> for Decimal {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Decimal> for PgNumeric {
|
||||
fn from(value: Decimal) -> Self {
|
||||
PgNumeric::from(&value)
|
||||
}
|
||||
}
|
||||
|
||||
// This impl is effectively infallible because `NUMERIC` has a greater range than `Decimal`.
|
||||
impl From<&'_ Decimal> for PgNumeric {
|
||||
// Impl has been manually validated.
|
||||
#[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
|
||||
fn from(decimal: &Decimal) -> Self {
|
||||
// `Decimal` added `is_zero()` as an inherent method in a more recent version
|
||||
if Zero::is_zero(decimal) {
|
||||
PgNumeric::Number {
|
||||
sign: PgNumericSign::Positive,
|
||||
scale: 0,
|
||||
weight: 0,
|
||||
digits: vec![],
|
||||
};
|
||||
if Decimal::is_zero(decimal) {
|
||||
return PgNumeric::ZERO;
|
||||
}
|
||||
|
||||
assert!(
|
||||
(0u32..=28).contains(&decimal.scale()),
|
||||
"decimal scale out of range {:?}",
|
||||
decimal.unpack(),
|
||||
);
|
||||
|
||||
// Cannot overflow: always in the range [0, 28]
|
||||
let scale = decimal.scale() as u16;
|
||||
|
||||
// A serialized version of the decimal number. The resulting byte array
|
||||
// will have the following representation:
|
||||
//
|
||||
// Bytes 1-4: flags
|
||||
// Bytes 5-8: lo portion of m
|
||||
// Bytes 9-12: mid portion of m
|
||||
// Bytes 13-16: high portion of m
|
||||
let mut mantissa = u128::from_le_bytes(decimal.serialize());
|
||||
let mut mantissa = decimal.mantissa().unsigned_abs();
|
||||
|
||||
// chop off the flags
|
||||
mantissa >>= 32;
|
||||
|
||||
// If our scale is not a multiple of 4, we need to go to the next
|
||||
// multiple.
|
||||
// If our scale is not a multiple of 4, we need to go to the next multiple.
|
||||
let groups_diff = scale % 4;
|
||||
if groups_diff > 0 {
|
||||
let remainder = 4 - groups_diff as u32;
|
||||
let power = 10u32.pow(remainder) as u128;
|
||||
|
||||
// Impossible to overflow; 0 <= mantissa <= 2^96,
|
||||
// and we're multiplying by at most 1,000 (giving us a result < 2^106)
|
||||
mantissa *= power;
|
||||
}
|
||||
|
||||
@@ -113,16 +121,32 @@ impl From<&'_ Decimal> for PgNumeric {
|
||||
|
||||
// Convert to base-10000.
|
||||
while mantissa != 0 {
|
||||
// Cannot overflow or wrap because of the modulus
|
||||
digits.push((mantissa % 10_000) as i16);
|
||||
mantissa /= 10_000;
|
||||
}
|
||||
|
||||
// Change the endianness.
|
||||
// We started with the low digits first, but they should actually be at the end.
|
||||
digits.reverse();
|
||||
|
||||
// Weight is number of digits on the left side of the decimal.
|
||||
let digits_after_decimal = (scale + 3) / 4;
|
||||
let weight = digits.len() as i16 - digits_after_decimal as i16 - 1;
|
||||
// Cannot overflow: strictly smaller than `scale`.
|
||||
let digits_after_decimal = scale.div_ceil(4) as i16;
|
||||
|
||||
// `mantissa` contains at most 29 decimal digits (log10(2^96)),
|
||||
// split into at most 8 4-digit segments.
|
||||
assert!(
|
||||
digits.len() <= 8,
|
||||
"digits.len() out of range: {}; unpacked: {:?}",
|
||||
digits.len(),
|
||||
decimal.unpack()
|
||||
);
|
||||
|
||||
// Cannot overflow; at most 8
|
||||
let num_digits = digits.len() as i16;
|
||||
|
||||
// Find how many 4-digit segments should go before the decimal point.
|
||||
// `weight = 0` puts just `digit[0]` before the decimal point, and the rest after.
|
||||
let weight = num_digits - digits_after_decimal - 1;
|
||||
|
||||
// Remove non-significant zeroes.
|
||||
while let Some(&0) = digits.last() {
|
||||
@@ -134,6 +158,7 @@ impl From<&'_ Decimal> for PgNumeric {
|
||||
false => PgNumericSign::Positive,
|
||||
true => PgNumericSign::Negative,
|
||||
},
|
||||
// Cannot overflow; between 0 and 28
|
||||
scale: scale as i16,
|
||||
weight,
|
||||
digits,
|
||||
@@ -160,7 +185,7 @@ impl Decode<'_, Postgres> for Decimal {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod decimal_to_pgnumeric {
|
||||
mod tests {
|
||||
use super::{Decimal, PgNumeric, PgNumericSign};
|
||||
use std::convert::TryFrom;
|
||||
|
||||
@@ -169,13 +194,13 @@ mod decimal_to_pgnumeric {
|
||||
let zero: Decimal = "0".parse().unwrap();
|
||||
|
||||
assert_eq!(
|
||||
PgNumeric::try_from(&zero).unwrap(),
|
||||
PgNumeric::Number {
|
||||
sign: PgNumericSign::Positive,
|
||||
scale: 0,
|
||||
weight: 0,
|
||||
digits: vec![]
|
||||
}
|
||||
PgNumeric::from(&zero),
|
||||
PgNumeric::ZERO,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
Decimal::try_from(&PgNumeric::ZERO).unwrap(),
|
||||
Decimal::ZERO
|
||||
);
|
||||
}
|
||||
|
||||
@@ -343,6 +368,48 @@ mod decimal_to_pgnumeric {
|
||||
assert_eq!(actual_decimal.scale(), 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn max_value() {
|
||||
let expected_numeric = PgNumeric::Number {
|
||||
sign: PgNumericSign::Positive,
|
||||
scale: 0,
|
||||
weight: 7,
|
||||
digits: vec![7, 9228, 1625, 1426, 4337, 5935, 4395, 0335],
|
||||
};
|
||||
assert_eq!(
|
||||
PgNumeric::try_from(&Decimal::MAX).unwrap(),
|
||||
expected_numeric
|
||||
);
|
||||
|
||||
let actual_decimal = Decimal::try_from(expected_numeric).unwrap();
|
||||
assert_eq!(actual_decimal, Decimal::MAX);
|
||||
// Value split by 10,000's to match the expected digits[]
|
||||
assert_eq!(actual_decimal.mantissa(), 7_9228_1625_1426_4337_5935_4395_0335);
|
||||
assert_eq!(actual_decimal.scale(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn max_value_max_scale() {
|
||||
let mut max_value_max_scale = Decimal::MAX;
|
||||
max_value_max_scale.set_scale(28).unwrap();
|
||||
|
||||
let expected_numeric = PgNumeric::Number {
|
||||
sign: PgNumericSign::Positive,
|
||||
scale: 28,
|
||||
weight: 0,
|
||||
digits: vec![7, 9228, 1625, 1426, 4337, 5935, 4395, 0335],
|
||||
};
|
||||
assert_eq!(
|
||||
PgNumeric::try_from(&max_value_max_scale).unwrap(),
|
||||
expected_numeric
|
||||
);
|
||||
|
||||
let actual_decimal = Decimal::try_from(expected_numeric).unwrap();
|
||||
assert_eq!(actual_decimal, max_value_max_scale);
|
||||
assert_eq!(actual_decimal.mantissa(), 79_228_162_514_264_337_593_543_950_335);
|
||||
assert_eq!(actual_decimal.scale(), 28);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn issue_423_four_digit() {
|
||||
// This is a regression test for https://github.com/launchbadge/sqlx/issues/423
|
||||
@@ -420,7 +487,4 @@ mod decimal_to_pgnumeric {
|
||||
assert_eq!(actual_decimal.mantissa(), 10000);
|
||||
assert_eq!(actual_decimal.scale(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn issue_666_trailing_zeroes_at_max_precision() {}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user