fix(pg_money): handle negative values correctly in PgMoney::from_decimal() (#1334)

closes #1321
This commit is contained in:
Austin Bonander 2021-07-21 16:29:20 -07:00 committed by GitHub
parent 531740550f
commit a8544fd503
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -6,6 +6,7 @@ use crate::{
types::Type,
};
use byteorder::{BigEndian, ByteOrder};
use std::convert::TryFrom;
use std::{
io,
ops::{Add, AddAssign, Sub, SubAssign},
@ -20,46 +21,100 @@ use std::{
///
/// Reading `MONEY` value in text format is not supported and will cause an error.
///
/// ### `locale_frac_digits`
/// This parameter corresponds to the number of digits after the decimal separator.
///
/// This value must match what Postgres is expecting for the locale set in the database
/// or else the decimal value you see on the client side will not match the `money` value
/// on the server side.
///
/// **For _most_ locales, this value is `2`.**
///
/// If you're not sure what locale your database is set to or how many decimal digits it specifies,
/// you can execute `SHOW lc_monetary;` to get the locale name, and then look it up in this list
/// (you can ignore the `.utf8` prefix):
/// https://lh.2xlibre.net/values/frac_digits/
///
/// If that link is dead and you're on a POSIX-compliant system (Unix, FreeBSD) you can also execute:
///
/// ```sh
/// $ LC_MONETARY=<value returned by `SHOW lc_monetary`> locale -k frac_digits
/// ```
///
/// And the value you want is `N` in `frac_digits=N`. If you have shell access to the database
/// server you should execute it there as available locales may differ between machines.
///
/// Note that if `frac_digits` for the locale is outside the range `[0, 10]`, Postgres assumes
/// it's a sentinel value and defaults to 2:
/// https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/cash.c#L114-L123
///
/// [`MONEY`]: https://www.postgresql.org/docs/current/datatype-money.html
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct PgMoney(pub i64);
pub struct PgMoney(
/// The raw integer value sent over the wire; for locales with `frac_digits=2` (i.e. most
/// of them), this will be the value in whole cents.
///
/// E.g. for `select '$123.45'::money` with a locale of `en_US` (`frac_digits=2`),
/// this will be `12345`.
///
/// If the currency of your locale does not have fractional units, e.g. Yen, then this will
/// just be the units of the currency.
///
/// See the type-level docs for an explanation of `locale_frac_units`.
pub i64,
);
impl PgMoney {
/// Convert the money value into a [`BigDecimal`] using the correct
/// precision defined in the PostgreSQL settings. The default precision is
/// two.
/// Convert the money value into a [`BigDecimal`] using `locale_frac_digits`.
///
/// See the type-level docs for an explanation of `locale_frac_digits`.
///
/// [`BigDecimal`]: crate::types::BigDecimal
#[cfg(feature = "bigdecimal")]
pub fn to_bigdecimal(self, scale: i64) -> bigdecimal::BigDecimal {
pub fn to_bigdecimal(self, locale_frac_digits: i64) -> bigdecimal::BigDecimal {
let digits = num_bigint::BigInt::from(self.0);
bigdecimal::BigDecimal::new(digits, scale)
bigdecimal::BigDecimal::new(digits, locale_frac_digits)
}
/// Convert the money value into a [`Decimal`] using the correct precision
/// defined in the PostgreSQL settings. The default precision is two.
/// Convert the money value into a [`Decimal`] using `locale_frac_digits`.
///
/// See the type-level docs for an explanation of `locale_frac_digits`.
///
/// [`Decimal`]: crate::types::Decimal
#[cfg(feature = "decimal")]
pub fn to_decimal(self, scale: u32) -> rust_decimal::Decimal {
rust_decimal::Decimal::new(self.0, scale)
pub fn to_decimal(self, locale_frac_digits: u32) -> rust_decimal::Decimal {
rust_decimal::Decimal::new(self.0, locale_frac_digits)
}
/// Convert a [`Decimal`] value into money using the correct precision
/// defined in the PostgreSQL settings. The default precision is two.
/// Convert a [`Decimal`] value into money using `locale_frac_digits`.
///
/// Conversion may involve a loss of precision.
/// See the type-level docs for an explanation of `locale_frac_digits`.
///
/// Note that `Decimal` has 96 bits of precision, but `PgMoney` only has 63 plus the sign bit.
/// If the value is larger than 63 bits it will be truncated.
///
/// [`Decimal`]: crate::types::Decimal
#[cfg(feature = "decimal")]
pub fn from_decimal(decimal: rust_decimal::Decimal, scale: u32) -> Self {
let cents = (decimal * rust_decimal::Decimal::new(10i64.pow(scale), 0)).round();
pub fn from_decimal(mut decimal: rust_decimal::Decimal, locale_frac_digits: u32) -> Self {
// this is all we need to convert to our expected locale's `frac_digits`
decimal.rescale(locale_frac_digits);
let mut buf: [u8; 8] = [0; 8];
buf.copy_from_slice(&cents.serialize()[4..12]);
/// a mask to bitwise-AND with an `i64` to zero the sign bit
const SIGN_MASK: i64 = i64::MAX;
Self(i64::from_le_bytes(buf))
let is_negative = decimal.is_sign_negative();
let serialized = decimal.serialize();
// interpret bytes `4..12` as an i64, ignoring the sign bit
// this is where truncation occurs
let value = i64::from_le_bytes(
*<&[u8; 8]>::try_from(&serialized[4..12])
.expect("BUG: slice of serialized should be 8 bytes"),
) & SIGN_MASK; // zero out the sign bit
// negate if necessary
Self(if is_negative { -value } else { value })
}
/// Convert a [`BigDecimal`](crate::types::BigDecimal) value into money using the correct precision
@ -67,12 +122,14 @@ impl PgMoney {
#[cfg(feature = "bigdecimal")]
pub fn from_bigdecimal(
decimal: bigdecimal::BigDecimal,
scale: u32,
locale_frac_digits: u32,
) -> Result<Self, BoxDynError> {
use bigdecimal::ToPrimitive;
let multiplier =
bigdecimal::BigDecimal::new(num_bigint::BigInt::from(10i128.pow(scale)), 0);
let multiplier = bigdecimal::BigDecimal::new(
num_bigint::BigInt::from(10i128.pow(locale_frac_digits)),
0,
);
let cents = decimal * multiplier;
@ -277,9 +334,25 @@ mod tests {
#[test]
#[cfg(feature = "decimal")]
fn conversion_from_decimal_works() {
let dec = rust_decimal::Decimal::new(12345, 2);
assert_eq!(
PgMoney(12345),
PgMoney::from_decimal(rust_decimal::Decimal::new(12345, 2), 2)
);
assert_eq!(PgMoney(12345), PgMoney::from_decimal(dec, 2));
assert_eq!(
PgMoney(12345),
PgMoney::from_decimal(rust_decimal::Decimal::new(123450, 3), 2)
);
assert_eq!(
PgMoney(-12345),
PgMoney::from_decimal(rust_decimal::Decimal::new(-123450, 3), 2)
);
assert_eq!(
PgMoney(-12300),
PgMoney::from_decimal(rust_decimal::Decimal::new(-123, 0), 2)
);
}
#[test]