diff --git a/sqlx-postgres/src/types/bigdecimal.rs b/sqlx-postgres/src/types/bigdecimal.rs index 5a6e500d3..869f85079 100644 --- a/sqlx-postgres/src/types/bigdecimal.rs +++ b/sqlx-postgres/src/types/bigdecimal.rs @@ -1,7 +1,6 @@ -use std::cmp; - use bigdecimal::BigDecimal; use num_bigint::{BigInt, Sign}; +use std::cmp; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; @@ -26,9 +25,17 @@ impl TryFrom for BigDecimal { type Error = BoxDynError; fn try_from(numeric: PgNumeric) -> Result { - let (digits, sign, weight) = match numeric { + Self::try_from(&numeric) + } +} + +impl TryFrom<&'_ PgNumeric> for BigDecimal { + type Error = BoxDynError; + + fn try_from(numeric: &'_ PgNumeric) -> Result { + let (digits, sign, weight) = match *numeric { PgNumeric::Number { - digits, + ref digits, sign, weight, .. @@ -50,11 +57,27 @@ impl TryFrom for BigDecimal { }; // weight is 0 if the decimal point falls after the first base-10000 digit + // + // `Vec` capacity cannot exceed `isize::MAX` bytes, so this cast can't wrap in practice. + #[allow(clippy::cast_possible_wrap)] let scale = (digits.len() as i64 - weight as i64 - 1) * 4; // no optimized algorithm for base-10 so use base-100 for faster processing let mut cents = Vec::with_capacity(digits.len() * 2); - for digit in &digits { + + #[allow( + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::cast_sign_loss + )] + for (i, &digit) in digits.iter().enumerate() { + if !PgNumeric::is_valid_digit(digit) { + return Err(format!( + "PgNumeric to BigDecimal: {i}th digit is out of range {digit}" + ) + .into()); + } + cents.push((digit / 100) as u8); cents.push((digit % 100) as u8); } @@ -79,9 +102,16 @@ impl TryFrom<&'_ BigDecimal> for PgNumeric { // FIXME: is there a way to iterate over the digits to avoid the Vec allocation let (sign, base_10) = integer.to_radix_be(10); + let base_10_len = i64::try_from(base_10.len()).map_err(|_| { + format!( + "BigDecimal base-10 length out of range for PgNumeric: {}", + base_10.len() + ) + })?; + // weight is positive power of 10000 // exp is the negative power of 10 - let weight_10 = base_10.len() as i64 - exp; + let weight_10 = base_10_len - exp; // scale is only nonzero when we have fractional digits // since `exp` is the _negative_ decimal exponent, it tells us @@ -103,19 +133,34 @@ impl TryFrom<&'_ BigDecimal> for PgNumeric { base_10.len() / 4 }; - let offset = weight_10.rem_euclid(4) as usize; + // For efficiency, we want to process the base-10 digits in chunks of 4, + // but that means we need to deal with the non-divisible remainder first. + let offset = weight_10.rem_euclid(4); + + // Do a checked conversion to the smallest integer, + // so we can widen arbitrarily without triggering lints. + let offset = u8::try_from(offset).unwrap_or_else(|_| { + panic!("BUG: `offset` should be in the range [0, 4) but is {offset}") + }); let mut digits = Vec::with_capacity(digits_len); - if let Some(first) = base_10.get(..offset) { + if let Some(first) = base_10.get(..offset as usize) { if !first.is_empty() { digits.push(base_10_to_10000(first)); } } else if offset != 0 { - digits.push(base_10_to_10000(&base_10) * 10i16.pow((offset - base_10.len()) as u32)); + // If we didn't hit the `if let Some` branch, + // then `base_10.len()` must strictly be smaller + #[allow(clippy::cast_possible_truncation)] + let power = (offset as usize - base_10.len()) as u32; + + digits.push(base_10_to_10000(&base_10) * 10i16.pow(power)); } - if let Some(rest) = base_10.get(offset..) { + if let Some(rest) = base_10.get(offset as usize..) { + // `chunk.len()` is always between 1 and 4 + #[allow(clippy::cast_possible_truncation)] digits.extend( rest.chunks(4) .map(|chunk| base_10_to_10000(chunk) * 10i16.pow(4 - chunk.len() as u32)), @@ -138,15 +183,13 @@ impl TryFrom<&'_ BigDecimal> for PgNumeric { #[doc=include_str!("bigdecimal-range.md")] impl Encode<'_, Postgres> for BigDecimal { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { - PgNumeric::try_from(self)?.encode(buf); + PgNumeric::try_from(self)?.encode(buf)?; Ok(IsNull::No) } fn size_hint(&self) -> usize { - // BigDecimal::digits() gives us base-10 digits, so we divide by 4 to get base-10000 digits - // and since this is just a hint we just always round up - 8 + (self.digits() / 4 + 1) as usize * 2 + PgNumeric::size_hint(self.digits()) } } diff --git a/sqlx-postgres/src/types/numeric.rs b/sqlx-postgres/src/types/numeric.rs index 641687291..3a01f2e62 100644 --- a/sqlx-postgres/src/types/numeric.rs +++ b/sqlx-postgres/src/types/numeric.rs @@ -1,4 +1,5 @@ use sqlx_core::bytes::Buf; +use std::num::Saturating; use crate::error::BoxDynError; use crate::PgArgumentBuffer; @@ -83,6 +84,27 @@ impl PgNumeric { scale: 0, }; + pub(crate) fn is_valid_digit(digit: i16) -> bool { + (0..10_000).contains(&digit) + } + + pub(crate) fn size_hint(decimal_digits: u64) -> usize { + let mut size_hint = Saturating(decimal_digits); + + // BigDecimal::digits() gives us base-10 digits, so we divide by 4 to get base-10000 digits + // and since this is just a hint we just always round up + size_hint /= 4; + size_hint += 1; + + // Times two bytes for each base-10000 digit + size_hint *= 2; + + // Plus `weight` and `scale` + size_hint += 8; + + usize::try_from(size_hint.0).unwrap_or(usize::MAX) + } + pub(crate) fn decode(mut buf: &[u8]) -> Result { // https://github.com/postgres/postgres/blob/bcd1c3630095e48bc3b1eb0fc8e8c8a7c851eba1/src/backend/utils/adt/numeric.c#L874 let num_digits = buf.get_u16(); @@ -104,11 +126,11 @@ impl PgNumeric { } } - /// ### Panics + /// ### Errors /// /// * If `digits.len()` overflows `i16` /// * If any element in `digits` is greater than or equal to 10000 - pub(crate) fn encode(&self, buf: &mut PgArgumentBuffer) { + pub(crate) fn encode(&self, buf: &mut PgArgumentBuffer) -> Result<(), String> { match *self { PgNumeric::Number { ref digits, @@ -116,18 +138,22 @@ impl PgNumeric { scale, weight, } => { - let digits_len: i16 = digits - .len() - .try_into() - .expect("PgNumeric.digits.len() should not overflow i16"); + let digits_len = i16::try_from(digits.len()).map_err(|_| { + format!( + "PgNumeric digits.len() ({}) should not overflow i16", + digits.len() + ) + })?; buf.extend(&digits_len.to_be_bytes()); buf.extend(&weight.to_be_bytes()); buf.extend(&(sign as i16).to_be_bytes()); buf.extend(&scale.to_be_bytes()); - for digit in digits { - debug_assert!(*digit < 10000, "PgNumeric digits must be in base-10000"); + for (i, &digit) in digits.iter().enumerate() { + if Self::is_valid_digit(digit) { + return Err(format!("{i}th PgNumeric digit out of range {digit}")); + } buf.extend(&digit.to_be_bytes()); } @@ -140,5 +166,7 @@ impl PgNumeric { buf.extend(&0_i16.to_be_bytes()); } } + + Ok(()) } } diff --git a/sqlx-postgres/src/types/rust_decimal.rs b/sqlx-postgres/src/types/rust_decimal.rs index 281bc7e46..8321e8281 100644 --- a/sqlx-postgres/src/types/rust_decimal.rs +++ b/sqlx-postgres/src/types/rust_decimal.rs @@ -50,7 +50,7 @@ impl TryFrom<&'_ PgNumeric> for Decimal { // Postgres returns an empty digit array for 0 return Ok(Decimal::ZERO); } - + let scale = u32::try_from(scale) .map_err(|_| format!("invalid scale value for Pg NUMERIC: {scale}"))?; @@ -171,7 +171,7 @@ impl From<&'_ Decimal> for PgNumeric { impl Encode<'_, Postgres> for Decimal { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { - PgNumeric::from(self).encode(buf); + PgNumeric::from(self).encode(buf)?; Ok(IsNull::No) }