fix: audit PgNumeric and usages for casts involving sign loss

This commit is contained in:
Austin Bonander
2024-08-16 14:06:38 -07:00
parent cac914fa21
commit 52c34a897a
3 changed files with 95 additions and 24 deletions

View File

@@ -1,7 +1,6 @@
use std::cmp;
use bigdecimal::BigDecimal;
use num_bigint::{BigInt, Sign};
use std::cmp;
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
@@ -26,9 +25,17 @@ impl TryFrom<PgNumeric> for BigDecimal {
type Error = BoxDynError;
fn try_from(numeric: PgNumeric) -> Result<Self, BoxDynError> {
let (digits, sign, weight) = match numeric {
Self::try_from(&numeric)
}
}
impl TryFrom<&'_ PgNumeric> for BigDecimal {
type Error = BoxDynError;
fn try_from(numeric: &'_ PgNumeric) -> Result<Self, Self::Error> {
let (digits, sign, weight) = match *numeric {
PgNumeric::Number {
digits,
ref digits,
sign,
weight,
..
@@ -50,11 +57,27 @@ impl TryFrom<PgNumeric> for BigDecimal {
};
// weight is 0 if the decimal point falls after the first base-10000 digit
//
// `Vec` capacity cannot exceed `isize::MAX` bytes, so this cast can't wrap in practice.
#[allow(clippy::cast_possible_wrap)]
let scale = (digits.len() as i64 - weight as i64 - 1) * 4;
// no optimized algorithm for base-10 so use base-100 for faster processing
let mut cents = Vec::with_capacity(digits.len() * 2);
for digit in &digits {
#[allow(
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_sign_loss
)]
for (i, &digit) in digits.iter().enumerate() {
if !PgNumeric::is_valid_digit(digit) {
return Err(format!(
"PgNumeric to BigDecimal: {i}th digit is out of range {digit}"
)
.into());
}
cents.push((digit / 100) as u8);
cents.push((digit % 100) as u8);
}
@@ -79,9 +102,16 @@ impl TryFrom<&'_ BigDecimal> for PgNumeric {
// FIXME: is there a way to iterate over the digits to avoid the Vec allocation
let (sign, base_10) = integer.to_radix_be(10);
let base_10_len = i64::try_from(base_10.len()).map_err(|_| {
format!(
"BigDecimal base-10 length out of range for PgNumeric: {}",
base_10.len()
)
})?;
// weight is positive power of 10000
// exp is the negative power of 10
let weight_10 = base_10.len() as i64 - exp;
let weight_10 = base_10_len - exp;
// scale is only nonzero when we have fractional digits
// since `exp` is the _negative_ decimal exponent, it tells us
@@ -103,19 +133,34 @@ impl TryFrom<&'_ BigDecimal> for PgNumeric {
base_10.len() / 4
};
let offset = weight_10.rem_euclid(4) as usize;
// For efficiency, we want to process the base-10 digits in chunks of 4,
// but that means we need to deal with the non-divisible remainder first.
let offset = weight_10.rem_euclid(4);
// Do a checked conversion to the smallest integer,
// so we can widen arbitrarily without triggering lints.
let offset = u8::try_from(offset).unwrap_or_else(|_| {
panic!("BUG: `offset` should be in the range [0, 4) but is {offset}")
});
let mut digits = Vec::with_capacity(digits_len);
if let Some(first) = base_10.get(..offset) {
if let Some(first) = base_10.get(..offset as usize) {
if !first.is_empty() {
digits.push(base_10_to_10000(first));
}
} else if offset != 0 {
digits.push(base_10_to_10000(&base_10) * 10i16.pow((offset - base_10.len()) as u32));
// If we didn't hit the `if let Some` branch,
// then `base_10.len()` must strictly be smaller
#[allow(clippy::cast_possible_truncation)]
let power = (offset as usize - base_10.len()) as u32;
digits.push(base_10_to_10000(&base_10) * 10i16.pow(power));
}
if let Some(rest) = base_10.get(offset..) {
if let Some(rest) = base_10.get(offset as usize..) {
// `chunk.len()` is always between 1 and 4
#[allow(clippy::cast_possible_truncation)]
digits.extend(
rest.chunks(4)
.map(|chunk| base_10_to_10000(chunk) * 10i16.pow(4 - chunk.len() as u32)),
@@ -138,15 +183,13 @@ impl TryFrom<&'_ BigDecimal> for PgNumeric {
#[doc=include_str!("bigdecimal-range.md")]
impl Encode<'_, Postgres> for BigDecimal {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
PgNumeric::try_from(self)?.encode(buf);
PgNumeric::try_from(self)?.encode(buf)?;
Ok(IsNull::No)
}
fn size_hint(&self) -> usize {
// BigDecimal::digits() gives us base-10 digits, so we divide by 4 to get base-10000 digits
// and since this is just a hint we just always round up
8 + (self.digits() / 4 + 1) as usize * 2
PgNumeric::size_hint(self.digits())
}
}

View File

@@ -1,4 +1,5 @@
use sqlx_core::bytes::Buf;
use std::num::Saturating;
use crate::error::BoxDynError;
use crate::PgArgumentBuffer;
@@ -83,6 +84,27 @@ impl PgNumeric {
scale: 0,
};
pub(crate) fn is_valid_digit(digit: i16) -> bool {
(0..10_000).contains(&digit)
}
pub(crate) fn size_hint(decimal_digits: u64) -> usize {
let mut size_hint = Saturating(decimal_digits);
// BigDecimal::digits() gives us base-10 digits, so we divide by 4 to get base-10000 digits
// and since this is just a hint we just always round up
size_hint /= 4;
size_hint += 1;
// Times two bytes for each base-10000 digit
size_hint *= 2;
// Plus `weight` and `scale`
size_hint += 8;
usize::try_from(size_hint.0).unwrap_or(usize::MAX)
}
pub(crate) fn decode(mut buf: &[u8]) -> Result<Self, BoxDynError> {
// https://github.com/postgres/postgres/blob/bcd1c3630095e48bc3b1eb0fc8e8c8a7c851eba1/src/backend/utils/adt/numeric.c#L874
let num_digits = buf.get_u16();
@@ -104,11 +126,11 @@ impl PgNumeric {
}
}
/// ### Panics
/// ### Errors
///
/// * If `digits.len()` overflows `i16`
/// * If any element in `digits` is greater than or equal to 10000
pub(crate) fn encode(&self, buf: &mut PgArgumentBuffer) {
pub(crate) fn encode(&self, buf: &mut PgArgumentBuffer) -> Result<(), String> {
match *self {
PgNumeric::Number {
ref digits,
@@ -116,18 +138,22 @@ impl PgNumeric {
scale,
weight,
} => {
let digits_len: i16 = digits
.len()
.try_into()
.expect("PgNumeric.digits.len() should not overflow i16");
let digits_len = i16::try_from(digits.len()).map_err(|_| {
format!(
"PgNumeric digits.len() ({}) should not overflow i16",
digits.len()
)
})?;
buf.extend(&digits_len.to_be_bytes());
buf.extend(&weight.to_be_bytes());
buf.extend(&(sign as i16).to_be_bytes());
buf.extend(&scale.to_be_bytes());
for digit in digits {
debug_assert!(*digit < 10000, "PgNumeric digits must be in base-10000");
for (i, &digit) in digits.iter().enumerate() {
if Self::is_valid_digit(digit) {
return Err(format!("{i}th PgNumeric digit out of range {digit}"));
}
buf.extend(&digit.to_be_bytes());
}
@@ -140,5 +166,7 @@ impl PgNumeric {
buf.extend(&0_i16.to_be_bytes());
}
}
Ok(())
}
}

View File

@@ -50,7 +50,7 @@ impl TryFrom<&'_ PgNumeric> for Decimal {
// Postgres returns an empty digit array for 0
return Ok(Decimal::ZERO);
}
let scale = u32::try_from(scale)
.map_err(|_| format!("invalid scale value for Pg NUMERIC: {scale}"))?;
@@ -171,7 +171,7 @@ impl From<&'_ Decimal> for PgNumeric {
impl Encode<'_, Postgres> for Decimal {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
PgNumeric::from(self).encode(buf);
PgNumeric::from(self).encode(buf)?;
Ok(IsNull::No)
}