From a8544fd503f071b3d9882f4dc429d7deabbceecb Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 21 Jul 2021 16:29:20 -0700 Subject: [PATCH] fix(pg_money): handle negative values correctly in `PgMoney::from_decimal()` (#1334) closes #1321 --- sqlx-core/src/postgres/types/money.rs | 119 +++++++++++++++++++++----- 1 file changed, 96 insertions(+), 23 deletions(-) diff --git a/sqlx-core/src/postgres/types/money.rs b/sqlx-core/src/postgres/types/money.rs index 2ae47dcd..9b9178bb 100644 --- a/sqlx-core/src/postgres/types/money.rs +++ b/sqlx-core/src/postgres/types/money.rs @@ -6,6 +6,7 @@ use crate::{ types::Type, }; use byteorder::{BigEndian, ByteOrder}; +use std::convert::TryFrom; use std::{ io, ops::{Add, AddAssign, Sub, SubAssign}, @@ -20,46 +21,100 @@ use std::{ /// /// Reading `MONEY` value in text format is not supported and will cause an error. /// +/// ### `locale_frac_digits` +/// This parameter corresponds to the number of digits after the decimal separator. +/// +/// This value must match what Postgres is expecting for the locale set in the database +/// or else the decimal value you see on the client side will not match the `money` value +/// on the server side. +/// +/// **For _most_ locales, this value is `2`.** +/// +/// If you're not sure what locale your database is set to or how many decimal digits it specifies, +/// you can execute `SHOW lc_monetary;` to get the locale name, and then look it up in this list +/// (you can ignore the `.utf8` prefix): +/// https://lh.2xlibre.net/values/frac_digits/ +/// +/// If that link is dead and you're on a POSIX-compliant system (Unix, FreeBSD) you can also execute: +/// +/// ```sh +/// $ LC_MONETARY= locale -k frac_digits +/// ``` +/// +/// And the value you want is `N` in `frac_digits=N`. If you have shell access to the database +/// server you should execute it there as available locales may differ between machines. +/// +/// Note that if `frac_digits` for the locale is outside the range `[0, 10]`, Postgres assumes +/// it's a sentinel value and defaults to 2: +/// https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/cash.c#L114-L123 +/// /// [`MONEY`]: https://www.postgresql.org/docs/current/datatype-money.html #[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub struct PgMoney(pub i64); +pub struct PgMoney( + /// The raw integer value sent over the wire; for locales with `frac_digits=2` (i.e. most + /// of them), this will be the value in whole cents. + /// + /// E.g. for `select '$123.45'::money` with a locale of `en_US` (`frac_digits=2`), + /// this will be `12345`. + /// + /// If the currency of your locale does not have fractional units, e.g. Yen, then this will + /// just be the units of the currency. + /// + /// See the type-level docs for an explanation of `locale_frac_units`. + pub i64, +); impl PgMoney { - /// Convert the money value into a [`BigDecimal`] using the correct - /// precision defined in the PostgreSQL settings. The default precision is - /// two. + /// Convert the money value into a [`BigDecimal`] using `locale_frac_digits`. + /// + /// See the type-level docs for an explanation of `locale_frac_digits`. /// /// [`BigDecimal`]: crate::types::BigDecimal #[cfg(feature = "bigdecimal")] - pub fn to_bigdecimal(self, scale: i64) -> bigdecimal::BigDecimal { + pub fn to_bigdecimal(self, locale_frac_digits: i64) -> bigdecimal::BigDecimal { let digits = num_bigint::BigInt::from(self.0); - bigdecimal::BigDecimal::new(digits, scale) + bigdecimal::BigDecimal::new(digits, locale_frac_digits) } - /// Convert the money value into a [`Decimal`] using the correct precision - /// defined in the PostgreSQL settings. The default precision is two. + /// Convert the money value into a [`Decimal`] using `locale_frac_digits`. + /// + /// See the type-level docs for an explanation of `locale_frac_digits`. /// /// [`Decimal`]: crate::types::Decimal #[cfg(feature = "decimal")] - pub fn to_decimal(self, scale: u32) -> rust_decimal::Decimal { - rust_decimal::Decimal::new(self.0, scale) + pub fn to_decimal(self, locale_frac_digits: u32) -> rust_decimal::Decimal { + rust_decimal::Decimal::new(self.0, locale_frac_digits) } - /// Convert a [`Decimal`] value into money using the correct precision - /// defined in the PostgreSQL settings. The default precision is two. + /// Convert a [`Decimal`] value into money using `locale_frac_digits`. /// - /// Conversion may involve a loss of precision. + /// See the type-level docs for an explanation of `locale_frac_digits`. + /// + /// Note that `Decimal` has 96 bits of precision, but `PgMoney` only has 63 plus the sign bit. + /// If the value is larger than 63 bits it will be truncated. /// /// [`Decimal`]: crate::types::Decimal #[cfg(feature = "decimal")] - pub fn from_decimal(decimal: rust_decimal::Decimal, scale: u32) -> Self { - let cents = (decimal * rust_decimal::Decimal::new(10i64.pow(scale), 0)).round(); + pub fn from_decimal(mut decimal: rust_decimal::Decimal, locale_frac_digits: u32) -> Self { + // this is all we need to convert to our expected locale's `frac_digits` + decimal.rescale(locale_frac_digits); - let mut buf: [u8; 8] = [0; 8]; - buf.copy_from_slice(¢s.serialize()[4..12]); + /// a mask to bitwise-AND with an `i64` to zero the sign bit + const SIGN_MASK: i64 = i64::MAX; - Self(i64::from_le_bytes(buf)) + let is_negative = decimal.is_sign_negative(); + let serialized = decimal.serialize(); + + // interpret bytes `4..12` as an i64, ignoring the sign bit + // this is where truncation occurs + let value = i64::from_le_bytes( + *<&[u8; 8]>::try_from(&serialized[4..12]) + .expect("BUG: slice of serialized should be 8 bytes"), + ) & SIGN_MASK; // zero out the sign bit + + // negate if necessary + Self(if is_negative { -value } else { value }) } /// Convert a [`BigDecimal`](crate::types::BigDecimal) value into money using the correct precision @@ -67,12 +122,14 @@ impl PgMoney { #[cfg(feature = "bigdecimal")] pub fn from_bigdecimal( decimal: bigdecimal::BigDecimal, - scale: u32, + locale_frac_digits: u32, ) -> Result { use bigdecimal::ToPrimitive; - let multiplier = - bigdecimal::BigDecimal::new(num_bigint::BigInt::from(10i128.pow(scale)), 0); + let multiplier = bigdecimal::BigDecimal::new( + num_bigint::BigInt::from(10i128.pow(locale_frac_digits)), + 0, + ); let cents = decimal * multiplier; @@ -277,9 +334,25 @@ mod tests { #[test] #[cfg(feature = "decimal")] fn conversion_from_decimal_works() { - let dec = rust_decimal::Decimal::new(12345, 2); + assert_eq!( + PgMoney(12345), + PgMoney::from_decimal(rust_decimal::Decimal::new(12345, 2), 2) + ); - assert_eq!(PgMoney(12345), PgMoney::from_decimal(dec, 2)); + assert_eq!( + PgMoney(12345), + PgMoney::from_decimal(rust_decimal::Decimal::new(123450, 3), 2) + ); + + assert_eq!( + PgMoney(-12345), + PgMoney::from_decimal(rust_decimal::Decimal::new(-123450, 3), 2) + ); + + assert_eq!( + PgMoney(-12300), + PgMoney::from_decimal(rust_decimal::Decimal::new(-123, 0), 2) + ); } #[test]