Postgres: implement numeric and BigDecimal support

This commit is contained in:
Austin Bonander 2020-03-17 16:43:50 -07:00
parent d1af2fe1b0
commit 94c40b3eb7
10 changed files with 508 additions and 11 deletions

View File

@ -46,6 +46,7 @@ mysql = [ "sqlx-core/mysql", "sqlx-macros/mysql" ]
sqlite = [ "sqlx-core/sqlite", "sqlx-macros/sqlite" ]
# types
bigdecimal = ["sqlx-core/bigdecimal_bigint", "sqlx-macros/bigdecimal"]
chrono = [ "sqlx-core/chrono", "sqlx-macros/chrono" ]
uuid = [ "sqlx-core/uuid", "sqlx-macros/uuid" ]

View File

@ -15,6 +15,9 @@ authors = [
[features]
default = [ "runtime-async-std" ]
unstable = []
# we need a feature which activates `num-bigint` as well because
# `bigdecimal` uses types from it but does not reexport (tsk tsk)
bigdecimal_bigint = ["bigdecimal", "num-bigint"]
postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac", "futures-channel/sink", "futures-util/sink" ]
mysql = [ "sha-1", "sha2", "generic-array", "num-bigint", "base64", "digest", "rand" ]
sqlite = [ "libsqlite3-sys" ]
@ -27,6 +30,7 @@ async-native-tls = { version = "0.3.2", default-features = false, optional = tru
async-std = { version = "1.5.0", features = [ "unstable" ], optional = true }
async-stream = { version = "0.2.1", default-features = false }
base64 = { version = "0.12.0", default-features = false, optional = true, features = [ "std" ] }
bigdecimal = { version = "0.1.0", optional = true }
bitflags = { version = "1.2.1", default-features = false }
byteorder = { version = "1.3.4", default-features = false, features = [ "std" ] }
chrono = { version = "0.4.10", default-features = false, features = [ "clock" ], optional = true }

View File

@ -1,6 +1,12 @@
/// Identifier (OID) of a Postgres data type, stored as a raw `u32`.
///
/// The associated constants on this type name the OIDs of built-in types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TypeId(pub(crate) u32);
// DEVELOPER PRO TIP: find builtin type OIDs easily by grepping this file
// https://github.com/postgres/postgres/blob/master/src/include/catalog/pg_type.dat
//
// If you have Postgres running locally you can also try
// SELECT oid, typarray FROM pg_type where typname = '<type name>'
#[allow(dead_code)]
impl TypeId {
// Scalar
@ -14,6 +20,8 @@ impl TypeId {
pub(crate) const FLOAT4: TypeId = TypeId(700);
pub(crate) const FLOAT8: TypeId = TypeId(701);
pub(crate) const NUMERIC: TypeId = TypeId(1700);
pub(crate) const TEXT: TypeId = TypeId(25);
pub(crate) const DATE: TypeId = TypeId(1082);
@ -38,6 +46,8 @@ impl TypeId {
pub(crate) const ARRAY_TEXT: TypeId = TypeId(1009);
// `_numeric` (the array type of `numeric`) has OID 1231 in `pg_type.dat`;
// 1700 is the OID of the scalar `numeric` type itself.
pub(crate) const ARRAY_NUMERIC: TypeId = TypeId(1231);
pub(crate) const ARRAY_DATE: TypeId = TypeId(1182);
pub(crate) const ARRAY_TIME: TypeId = TypeId(1183);
pub(crate) const ARRAY_TIMESTAMP: TypeId = TypeId(1115);

View File

@ -0,0 +1,279 @@
use bigdecimal::BigDecimal;
use num_bigint::{BigInt, Sign};
use std::convert::{TryFrom, TryInto};
use super::numeric::{PgNumeric, PgNumericSign};
use crate::database::{Database, HasRawValue};
use crate::encode::Encode;
use crate::postgres::{PgValue, Postgres};
use crate::types::Type;
use crate::decode::Decode;
use crate::Error;
use std::cmp;
/// `BigDecimal` reports the same Postgres type information as [`PgNumeric`].
impl Type<Postgres> for BigDecimal {
fn type_info() -> <Postgres as Database>::TypeInfo {
// delegate so both types always agree on the type they map to
<PgNumeric as Type<Postgres>>::type_info()
}
}
/// Converts a `BigDecimal` into the base-10000 digit representation of [`PgNumeric`].
///
/// ### Errors
/// Returns `TryFromIntError` if the computed scale or weight does not fit in `i16`.
impl TryFrom<&'_ BigDecimal> for PgNumeric {
    type Error = std::num::TryFromIntError;

    fn try_from(bd: &'_ BigDecimal) -> Result<Self, Self::Error> {
        // folds up to four base-10 digits (most significant first) into one base-10000 digit
        let base_10_to_10000 = |chunk: &[u8]| chunk.iter().fold(0i16, |a, &d| a * 10 + d as i16);

        // this implementation unfortunately has a number of redundant copies because BigDecimal
        // doesn't give us even immutable access to its internal representation, and neither
        // does `BigInt` or `BigUint`
        let (bigint, exp) = bd.as_bigint_and_exponent();

        // routine is specifically optimized for base-10
        let (sign, base_10) = bigint.to_radix_be(10);

        // number of base-10 digits in front of the decimal point; zero or negative when the
        // value is purely fractional (`exp` is the negative power of 10)
        let weight_10 = base_10.len() as i64 - exp;

        // scale is only nonzero when we have fractional digits
        let scale: i16 = cmp::max(0, exp).try_into()?;

        // weight is the power of 10000 of the first base-10000 digit
        let weight: i16 = if weight_10 <= 0 {
            weight_10 / 4 - 1
        } else {
            // `- 1` because 1 to 4 integer digits all live in the weight-0 group;
            // without it `1000` would be encoded with weight 1
            (weight_10 - 1) / 4
        }
        .try_into()?;

        let digits_len = if base_10.len() % 4 != 0 {
            base_10.len() / 4 + 1
        } else {
            base_10.len() / 4
        };

        // number of base-10 digits that belong to the first (possibly partial) group;
        // `rem_euclid` keeps this in `0..4` for negative `weight_10` as well
        let offset = weight_10.rem_euclid(4) as usize;

        let mut digits = Vec::with_capacity(digits_len);

        if let Some(first) = base_10.get(..offset) {
            if !first.is_empty() {
                digits.push(base_10_to_10000(first));
            }
        } else if offset != 0 {
            // fewer significant digits than the first group spans (e.g. `0.01` or `1e2`):
            // the missing low-order places of that group are zeros
            digits.push(base_10_to_10000(&base_10) * 10i16.pow((offset - base_10.len()) as u32));
        }

        if let Some(rest) = base_10.get(offset..) {
            digits.extend(
                rest.chunks(4)
                    // a short final chunk is the high-order part of its group
                    .map(|chunk| base_10_to_10000(chunk) * 10i16.pow(4 - chunk.len() as u32)),
            );
        }

        // trailing zero groups carry no information
        while let Some(&0) = digits.last() {
            digits.pop();
        }

        Ok(PgNumeric {
            sign: match sign {
                Sign::Plus | Sign::NoSign => PgNumericSign::Positive,
                Sign::Minus => PgNumericSign::Negative,
            },
            scale,
            weight,
            digits,
        })
    }
}
/// Converts the `NUMERIC` wire representation into a `BigDecimal`.
///
/// ### Errors
/// Returns `Error::Decode` if the value is `NaN` (which `BigDecimal` cannot represent)
/// or if any digit is outside of base-10000.
impl TryFrom<PgNumeric> for BigDecimal {
    type Error = crate::Error;

    fn try_from(numeric: PgNumeric) -> crate::Result<Self> {
        // NaN has to be rejected before looking at the digits; otherwise a NaN that carries
        // no digits would be decoded as zero
        let sign = match numeric.sign {
            PgNumericSign::NotANumber => {
                return Err(crate::Error::Decode(
                    "BigDecimal does not support NaN values".into(),
                ))
            }
            PgNumericSign::Positive => Sign::Plus,
            PgNumericSign::Negative => Sign::Minus,
        };

        if numeric.digits.is_empty() {
            // no digits means zero; build it directly instead of handing `BigInt`
            // an empty digit slice
            return Ok(BigDecimal::new(BigInt::from(0i64), 0));
        }

        // `scale` is effectively the number of places left to shift the decimal point
        // weight is 0 if the decimal point falls after the first base-10000 digit
        let scale = (numeric.digits.len() as i64 - numeric.weight as i64 - 1) * 4;

        // no optimized algorithm for base-10 so use base-100 for faster processing
        let mut cents = Vec::with_capacity(numeric.digits.len() * 2);

        for &digit in &numeric.digits {
            // validate before the `as u8` casts below, which would otherwise wrap
            // and could let a malformed digit through unnoticed
            if digit < 0 || digit >= 10000 {
                return Err(crate::Error::Decode(
                    format!("PgNumeric digit out of range for base-10000: {}", digit).into(),
                ));
            }

            cents.push((digit / 100) as u8);
            cents.push((digit % 100) as u8);
        }

        let bigint = BigInt::from_radix_be(sign, &cents, 100)
            .expect("BUG digit outside of given radix, check math above");

        Ok(BigDecimal::new(bigint, scale))
    }
}
/// ### Panics
/// If this `BigDecimal` cannot be represented by [PgNumeric].
impl Encode<Postgres> for BigDecimal {
    fn encode(&self, buf: &mut <Postgres as Database>::RawBuffer) {
        let numeric = PgNumeric::try_from(self)
            .expect("BigDecimal magnitude too great for Postgres NUMERIC type");

        numeric.encode(buf);
    }

    fn size_hint(&self) -> usize {
        // fixed-size part of the encoding that precedes the digits
        const HEADER_BYTES: usize = 8;

        // `BigDecimal::digits()` counts base-10 digits; four of them fit in one base-10000
        // digit, and as this is only a hint we always round up
        let base_10000_digits = (self.digits() / 4 + 1) as usize;

        HEADER_BYTES + base_10000_digits * 2
    }
}
impl Decode<'_, Postgres> for BigDecimal {
fn decode(value: Option<PgValue>) -> crate::Result<Self> {
match value.try_into()? {
PgValue::Binary(binary) => PgNumeric::from_bytes(binary)?.try_into(),
PgValue::Text(text) => text
.parse::<BigDecimal>()
.map_err(|e| crate::Error::Decode(e.into())),
}
}
}
#[test]
fn test_bigdecimal_to_pgnumeric() {
    // parses `input` as a `BigDecimal` and checks that it converts to a positive
    // `PgNumeric` with exactly the given fields
    let assert_converts = |input: &str, scale: i16, weight: i16, digits: Vec<i16>| {
        let decimal: BigDecimal = input.parse().unwrap();

        assert_eq!(
            PgNumeric::try_from(&decimal).unwrap(),
            PgNumeric {
                sign: PgNumericSign::Positive,
                scale,
                weight,
                digits
            }
        );
    };

    // integers that fit in a single base-10000 digit
    assert_converts("1", 0, 0, vec![1]);
    assert_converts("10", 0, 0, vec![10]);
    assert_converts("100", 0, 0, vec![100]);

    // BigDecimal doesn't normalize here
    assert_converts("10000", 0, 1, vec![1]);

    // two base-10000 digits
    assert_converts("12345", 0, 1, vec![1, 2345]);

    // fractional values
    assert_converts("0.1", 1, -1, vec![1000]);
    assert_converts("1.2345", 4, 0, vec![1, 2345]);
    assert_converts("0.12345", 5, -1, vec![1234, 5000]);
    assert_converts("0.01234", 5, -1, vec![123, 4000]);
    assert_converts("12345.67890", 5, 1, vec![1, 2345, 6789]);

    // a single digit in the second fractional group
    assert_converts("0.00001234", 8, -2, vec![1234]);
}

View File

@ -11,11 +11,16 @@ mod bool;
mod bytes;
mod float;
mod int;
mod numeric;
mod record;
mod str;
pub use self::numeric::{PgNumeric, PgNumericSign};
pub use self::record::{PgRecordDecoder, PgRecordEncoder};
#[cfg(feature = "bigdecimal_bigint")]
mod bigdecimal;
#[cfg(feature = "chrono")]
mod chrono;

View File

@ -0,0 +1,113 @@
use byteorder::BigEndian;
use std::convert::TryInto;
use crate::database::{Database, HasRawValue};
use crate::decode::Decode;
use crate::encode::Encode;
use crate::io::{Buf, BufMut};
use crate::postgres::protocol::TypeId;
use crate::postgres::{PgTypeInfo, PgValue, Postgres};
use crate::types::Type;
use crate::Error;
/// Wire representation of a Postgres NUMERIC type
#[derive(Debug, PartialEq, Eq)]
pub struct PgNumeric {
/// Sign flag of the value; see [`PgNumericSign`].
pub sign: PgNumericSign,
/// NOTE(review): presumably the number of base-10 digits after the decimal point
/// (Postgres' display scale) — confirm against the protocol documentation.
pub scale: i16,
/// NOTE(review): presumably the power of 10000 of the first entry in `digits` — confirm.
pub weight: i16,
/// Digits of the value, most significant first. Encoding asserts that every digit
/// is below 10000 (base-10000).
pub digits: Vec<i16>,
}
/// Sign flag of a [`PgNumeric`].
///
/// The discriminants are the raw `i16` values written to and read from the wire.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[repr(i16)]
pub enum PgNumericSign {
/// The value is positive (zero is also encoded with this flag by the `BigDecimal` conversion).
Positive = 0x0000,
/// The value is negative.
Negative = 0x4000,
/// The value is `NaN`; written as a negative `i16` because `0xC000` does not fit in `i16`.
NotANumber = -0x4000, // 0xC000
}
impl PgNumericSign {
fn from_u16(sign: i16) -> crate::Result<PgNumericSign> {
// https://github.com/postgres/postgres/blob/bcd1c3630095e48bc3b1eb0fc8e8c8a7c851eba1/src/backend/utils/adt/numeric.c#L167-L170
match sign {
0x0000 => Ok(PgNumericSign::Positive),
0x4000 => Ok(PgNumericSign::Negative),
-0x4000 => Ok(PgNumericSign::NotANumber),
_ => Err(Error::Decode(
format!("unknown value for PgNumericSign: {:#04X}", sign).into(),
)),
}
}
}
/// `PgNumeric` maps to the Postgres `NUMERIC` type (`TypeId::NUMERIC`).
impl Type<Postgres> for PgNumeric {
fn type_info() -> <Postgres as Database>::TypeInfo {
PgTypeInfo::new(TypeId::NUMERIC, "NUMERIC")
}
}
impl PgNumeric {
    /// Parses a `PgNumeric` from its binary wire format.
    ///
    /// The input starts with four big-endian 16-bit fields (digit count, weight, sign,
    /// scale) followed by `digit count` big-endian `i16` digits. Fails if the buffer is
    /// too short or the sign field is not a known flag.
    pub(crate) fn from_bytes(mut bytes: &[u8]) -> crate::Result<Self> {
        // https://github.com/postgres/postgres/blob/bcd1c3630095e48bc3b1eb0fc8e8c8a7c851eba1/src/backend/utils/adt/numeric.c#L874
        let num_digits = bytes.get_u16::<BigEndian>()?;
        let weight = bytes.get_i16::<BigEndian>()?;
        let raw_sign = bytes.get_i16::<BigEndian>()?;
        let scale = bytes.get_i16::<BigEndian>()?;

        let mut digits = Vec::with_capacity(usize::from(num_digits));
        for _ in 0..num_digits {
            digits.push(bytes.get_i16::<BigEndian>()?);
        }

        // the sign is validated only after all digits have been read
        let sign = PgNumericSign::from_u16(raw_sign)?;

        Ok(PgNumeric {
            sign,
            scale,
            weight,
            digits,
        })
    }
}
/// ### Note
/// Receiving `PgNumeric` is only supported for the Postgres binary (prepared statements) protocol.
impl Decode<'_, Postgres> for PgNumeric {
fn decode(value: Option<PgValue>) -> crate::Result<Self> {
if let PgValue::Binary(bytes) = value.try_into()? {
Self::from_bytes(bytes)
} else {
Err(Error::Decode(
format!("`PgNumeric` can only be decoded from the binary protocol").into(),
))
}
}
}
/// ### Panics
///
/// * If `self.digits.len()` overflows `i16`
/// * If any element in `self.digits` is greater than or equal to 10000
impl Encode<Postgres> for PgNumeric {
    fn encode(&self, buf: &mut Vec<u8>) {
        let digits_len: i16 = self
            .digits
            .len()
            .try_into()
            .expect("PgNumeric.digits.len() should not overflow i16");

        // header fields, in wire order: digit count, weight, sign flag, scale
        let header = [digits_len, self.weight, self.sign as i16, self.scale];
        for field in header.iter().copied() {
            buf.put_i16::<BigEndian>(field);
        }

        // NOTE(review): only the upper bound of each digit is checked here
        for digit in self.digits.iter().copied() {
            assert!(digit < 10000, "PgNumeric digits must be in base-10000");
            buf.put_i16::<BigEndian>(digit);
        }
    }

    fn size_hint(&self) -> usize {
        // four `i16` header fields plus one `i16` per digit
        let field_size = std::mem::size_of::<i16>();

        4 * field_size + self.digits.len() * field_size
    }
}

View File

@ -14,6 +14,10 @@ pub mod chrono {
pub use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
}
#[cfg(feature = "bigdecimal")]
#[cfg_attr(docsrs, doc(cfg(feature = "bigdecimal")))]
pub use bigdecimal::BigDecimal;
pub trait TypeInfo: Debug + Display + Clone {
/// Compares type information to determine if `other` is compatible at the Rust level
/// with `self`.

View File

@ -27,6 +27,7 @@ postgres = [ "sqlx/postgres" ]
sqlite = [ "sqlx/sqlite" ]
# type
bigdecimal = [ "sqlx/bigdecimal_bigint" ]
chrono = [ "sqlx/chrono" ]
uuid = [ "sqlx/uuid" ]

View File

@ -67,14 +67,33 @@ macro_rules! test_prepared_type {
$(
let query = format!($crate::[< $db _query_for_test_prepared_type >]!(), $text);
let rec: (bool, $ty) = sqlx::query_as(&query)
let rec: (bool, String, $ty, $ty) = sqlx::query_as(&query)
.bind($value)
.bind($value)
.bind($value)
.fetch_one(&mut conn)
.await?;
assert!(rec.0, "value returned from server: {:?}", rec.1);
assert!($value == rec.1);
assert!(rec.0,
"DB value mismatch; given value: {:?}\n\
as received: {:?}\n\
as returned: {:?}\n\
round-trip: {:?}",
$value, rec.1, rec.2, rec.3);
assert_eq!($value, rec.2,
"DB value mismatch; given value: {:?}\n\
as received: {:?}\n\
as returned: {:?}\n\
round-trip: {:?}",
$value, rec.1, rec.2, rec.3);
assert_eq!($value, rec.3,
"DB value mismatch; given value: {:?}\n\
as received: {:?}\n\
as returned: {:?}\n\
round-trip: {:?}",
$value, rec.1, rec.2, rec.3);
)+
Ok(())
@ -86,20 +105,20 @@ macro_rules! test_prepared_type {
/// Query template that `test_prepared_type!` passes to `format!` for MySQL.
///
/// The `{}` / `{0}` placeholder is filled with the SQL text under test; every `?` is a
/// bind parameter that receives the Rust value under test.
///
/// NOTE(review): MySQL's `CAST` may not accept `text` as a target type (`char` is the
/// documented string target) — confirm before relying on the `_1` column.
#[macro_export]
macro_rules! MySql_query_for_test_prepared_type {
() => {
"SELECT {} <=> ?, ? as _1"
"SELECT {0} <=> ?, cast(? as text) as _1, {0} as _2, ? as _3"
};
}
/// Query template that `test_prepared_type!` passes to `format!` for SQLite.
///
/// The `{}` / `{0}` placeholder is filled with the SQL text under test; every `?` is a
/// bind parameter that receives the Rust value under test.
#[macro_export]
macro_rules! Sqlite_query_for_test_prepared_type {
() => {
"SELECT {} is ?, ? as _1"
"SELECT {0} is ?, cast(? as text) as _1, {0} as _2, ? as _3"
};
}
/// Query template that `test_prepared_type!` passes to `format!` for Postgres.
///
/// The `{}` / `{0}` placeholder is filled with the SQL text under test; `$1`..`$3` are
/// bind parameters that receive the Rust value under test.
#[macro_export]
macro_rules! Postgres_query_for_test_prepared_type {
() => {
"SELECT {} is not distinct from $1, $2 as _1"
"SELECT {0} is not distinct from $1, $2::text as _1, {0}, $3 as _3"
};
}

View File

@ -1,11 +1,11 @@
use std::sync::atomic::{AtomicU32, Ordering};
use sqlx::decode::Decode;
use sqlx::encode::Encode;
use sqlx::postgres::types::PgRecordEncoder;
use sqlx::postgres::types::{PgNumeric, PgNumericSign, PgRecordDecoder, PgRecordEncoder};
use sqlx::postgres::{PgQueryAs, PgTypeInfo, PgValue};
use sqlx::{Cursor, Executor, Postgres, Row, Type};
use sqlx_core::postgres::types::PgRecordDecoder;
use sqlx_test::{new, test_type};
use std::sync::atomic::{AtomicU32, Ordering};
use sqlx_test::{new, test_prepared_type, test_type};
test_type!(null(
Postgres,
@ -49,6 +49,66 @@ test_type!(bytea(
== vec![0_u8, 0, 0, 0, 0x52]
));
// `PgNumeric` can only be decoded from the binary (prepared statement) protocol,
// so it is exercised with `test_prepared_type!`
// Each case pairs a SQL literal with the `PgNumeric` it is expected to correspond to.
test_prepared_type!(numeric(
Postgres,
PgNumeric,
// a single base-10000 digit
"1::numeric"
== PgNumeric {
sign: PgNumericSign::Positive,
weight: 0,
scale: 0,
digits: vec![1]
},
// four base-10 digits still fit in one base-10000 digit
"1234::numeric"
== PgNumeric {
sign: PgNumericSign::Positive,
weight: 0,
scale: 0,
digits: vec![1234]
},
// 10000 = 1 * 10000^1: one digit with weight 1
"10000::numeric"
== PgNumeric {
sign: PgNumericSign::Positive,
weight: 1,
scale: 0,
digits: vec![1]
},
// 0.1 = 1000 * 10000^-1
"0.1::numeric"
== PgNumeric {
sign: PgNumericSign::Positive,
weight: -1,
scale: 1,
digits: vec![1000]
},
// 0.01234 = 0123 * 10000^-1 + 4000 * 10000^-2
"0.01234::numeric"
== PgNumeric {
sign: PgNumericSign::Positive,
weight: -1,
scale: 5,
digits: vec![123, 4000]
},
// 12.34 = 12 * 10000^0 + 3400 * 10000^-1
"12.34::numeric"
== PgNumeric {
sign: PgNumericSign::Positive,
weight: 0,
scale: 2,
digits: vec![12, 3400]
}
));
// Pairs SQL `numeric` literals with the `BigDecimal` parsed from the same decimal text.
// NOTE(review): presumably `test_type!` checks the value in both directions — confirm
// against the macro definition.
#[cfg(feature = "bigdecimal")]
test_type!(decimal(
Postgres,
sqlx::types::BigDecimal,
"1::numeric" == "1".parse::<sqlx::types::BigDecimal>().unwrap(),
"10000::numeric" == "10000".parse::<sqlx::types::BigDecimal>().unwrap(),
"0.1::numeric" == "0.1".parse::<sqlx::types::BigDecimal>().unwrap(),
"0.01234::numeric" == "0.01234".parse::<sqlx::types::BigDecimal>().unwrap(),
"12.34::numeric" == "12.34".parse::<sqlx::types::BigDecimal>().unwrap(),
"12345.6789::numeric" == "12345.6789".parse::<sqlx::types::BigDecimal>().unwrap()
));
#[cfg(feature = "uuid")]
test_type!(uuid(
Postgres,
@ -61,9 +121,10 @@ test_type!(uuid(
#[cfg(feature = "chrono")]
mod chrono {
use super::*;
use sqlx::types::chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
use super::*;
test_type!(chrono_date(
Postgres,
NaiveDate,