refactor: prepare to support postgres ranges

- Remove Type bound from Encode + Decode which removes the defaults 
   for produces() and accepts(). This allows custom type implementations
   to be more flexible.
This commit is contained in:
Caio 2020-06-07 13:19:31 -03:00 committed by Ryan Leckey
parent 98a0de2cfd
commit d4329e98d4
51 changed files with 1056 additions and 47 deletions

View File

@ -2,6 +2,7 @@
use crate::database::{Database, HasArguments};
use crate::encode::Encode;
use crate::types::Type;
/// A tuple of arguments to be sent to the database.
pub trait Arguments<'q>: Send + Sized + Default {
@ -14,7 +15,7 @@ pub trait Arguments<'q>: Send + Sized + Default {
/// Add the value to the end of the arguments.
fn add<T>(&mut self, value: T)
where
T: 'q + Encode<'q, Self::Database>;
T: 'q + Encode<'q, Self::Database> + Type<Self::Database>;
}
pub trait IntoArguments<'q, DB: HasArguments<'q>>: Sized + Send {

View File

@ -2,16 +2,13 @@
use crate::database::{Database, HasValueRef};
use crate::error::BoxDynError;
use crate::types::Type;
use crate::value::ValueRef;
/// A type that can be decoded from the database.
pub trait Decode<'r, DB: Database>: Sized + Type<DB> {
pub trait Decode<'r, DB: Database>: Sized {
/// Determines if a value of this type can be created from a value with the
/// given type information.
fn accepts(ty: &DB::TypeInfo) -> bool {
*ty == Self::type_info()
}
fn accepts(ty: &DB::TypeInfo) -> bool;
/// Decode a new value of this type using a raw value from the database.
fn decode(value: <DB as HasValueRef<'r>>::ValueRef) -> Result<Self, BoxDynError>;

View File

@ -2,7 +2,6 @@
use std::mem;
use crate::database::{Database, HasArguments};
use crate::types::Type;
/// The return type of [Encode::encode].
pub enum IsNull {
@ -16,11 +15,7 @@ pub enum IsNull {
}
/// Encode a single value to be sent to the database.
pub trait Encode<'q, DB: Database>: Type<DB> {
fn produces(&self) -> DB::TypeInfo {
Self::type_info()
}
pub trait Encode<'q, DB: Database> {
/// Writes the value of `self` into `buf` in the expected format for the database.
#[must_use]
fn encode(self, buf: &mut <DB as HasArguments<'q>>::ArgumentBuffer) -> IsNull
@ -37,6 +32,10 @@ pub trait Encode<'q, DB: Database>: Type<DB> {
#[must_use]
fn encode_by_ref(&self, buf: &mut <DB as HasArguments<'q>>::ArgumentBuffer) -> IsNull;
fn produces(&self) -> Option<DB::TypeInfo> {
None
}
#[inline]
fn size_hint(&self) -> usize {
mem::size_of_val(self)
@ -48,7 +47,7 @@ where
T: Encode<'q, DB>,
{
#[inline]
fn produces(&self) -> DB::TypeInfo {
fn produces(&self) -> Option<DB::TypeInfo> {
(**self).produces()
}
@ -68,18 +67,18 @@ where
}
}
#[allow(unused_macros)]
macro_rules! impl_encode_for_option {
($DB:ident) => {
impl<'q, T: 'q + crate::encode::Encode<'q, $DB>> crate::encode::Encode<'q, $DB>
for Option<T>
impl<'q, T> crate::encode::Encode<'q, $DB> for Option<T>
where
T: crate::encode::Encode<'q, $DB> + crate::types::Type<$DB> + 'q,
{
#[inline]
fn produces(&self) -> <$DB as crate::database::Database>::TypeInfo {
fn produces(&self) -> Option<<$DB as crate::database::Database>::TypeInfo> {
if let Some(v) = self {
v.produces()
} else {
T::type_info()
T::type_info().into()
}
}

View File

@ -33,7 +33,7 @@ macro_rules! impl_from_row_for_tuple {
impl<'r, R, $($T,)+> FromRow<'r, R> for ($($T,)+)
where
R: Row,
$($T: crate::decode::Decode<'r, R::Database>,)+
$($T: crate::decode::Decode<'r, R::Database> + crate::types::Type<R::Database>,)+
{
#[inline]
fn from_row(row: &'r R) -> Result<Self, Error> {

View File

@ -7,12 +7,12 @@ mod int;
mod str;
impl<'q, T: 'q + Encode<'q, Mssql>> Encode<'q, Mssql> for Option<T> {
fn produces(&self) -> MssqlTypeInfo {
fn produces(&self) -> Option<MssqlTypeInfo> {
if let Some(v) = self {
v.produces()
} else {
// MSSQL requires a special NULL type ID
MssqlTypeInfo(TypeInfo::new(DataType::Null, 0))
Some(MssqlTypeInfo(TypeInfo::new(DataType::Null, 0)))
}
}

View File

@ -19,9 +19,9 @@ impl Type<Mssql> for String {
}
impl Encode<'_, Mssql> for &'_ str {
fn produces(&self) -> MssqlTypeInfo {
fn produces(&self) -> Option<MssqlTypeInfo> {
// an empty string needs to be encoded as `nvarchar(2)`
MssqlTypeInfo(TypeInfo {
Some(MssqlTypeInfo(TypeInfo {
ty: DataType::NVarChar,
size: ((self.len() * 2) as u32).max(2),
scale: 0,
@ -34,7 +34,7 @@ impl Encode<'_, Mssql> for &'_ str {
sort: 52,
version: 0,
}),
})
}))
}
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
@ -45,7 +45,7 @@ impl Encode<'_, Mssql> for &'_ str {
}
impl Encode<'_, Mssql> for String {
fn produces(&self) -> MssqlTypeInfo {
fn produces(&self) -> Option<MssqlTypeInfo> {
<&str as Encode<Mssql>>::produces(&self.as_str())
}

View File

@ -3,6 +3,7 @@ use std::ops::{Deref, DerefMut};
use crate::arguments::Arguments;
use crate::encode::{Encode, IsNull};
use crate::mysql::{MySql, MySqlTypeInfo};
use crate::types::Type;
/// Implementation of [`Arguments`] for MySQL.
#[derive(Debug, Default)]
@ -22,9 +23,9 @@ impl<'q> Arguments<'q> for MySqlArguments {
fn add<T>(&mut self, value: T)
where
T: Encode<'q, Self::Database>,
T: Encode<'q, Self::Database> + Type<Self::Database>,
{
let ty = value.produces();
let ty = value.produces().unwrap_or_else(T::type_info);
let index = self.types.len();
self.types.push(ty);

View File

@ -20,9 +20,17 @@ impl Encode<'_, MySql> for BigDecimal {
IsNull::No
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
}
impl Decode<'_, MySql> for BigDecimal {
fn accepts(ty: &MySqlTypeInfo) -> bool {
*ty == <Self as Type<MySql>>::type_info()
}
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(value.as_str()?.parse()?)
}

View File

@ -15,6 +15,10 @@ impl Encode<'_, MySql> for bool {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
<i8 as Encode<MySql>>::encode(*self as i8, buf)
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
}
impl Decode<'_, MySql> for bool {

View File

@ -18,6 +18,10 @@ impl Encode<'_, MySql> for &'_ [u8] {
IsNull::No
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
}
impl<'r> Decode<'r, MySql> for &'r [u8] {
@ -50,6 +54,10 @@ impl Encode<'_, MySql> for Vec<u8> {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
<&[u8] as Encode<MySql>>::encode(&**self, buf)
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
}
impl Decode<'_, MySql> for Vec<u8> {

View File

@ -21,6 +21,10 @@ impl Encode<'_, MySql> for DateTime<Utc> {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
Encode::<MySql>::encode(&self.naive_utc(), buf)
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
}
impl<'r> Decode<'r, MySql> for DateTime<Utc> {
@ -58,6 +62,10 @@ impl Encode<'_, MySql> for NaiveTime {
IsNull::No
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
fn size_hint(&self) -> usize {
if self.nanosecond() == 0 {
// if micro_seconds is 0, length is 8 and micro_seconds is not sent
@ -70,6 +78,10 @@ impl Encode<'_, MySql> for NaiveTime {
}
impl<'r> Decode<'r, MySql> for NaiveTime {
fn accepts(ty: &MySqlTypeInfo) -> bool {
*ty == <Self as Type<MySql>>::type_info()
}
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
match value.format() {
MySqlValueFormat::Binary => {
@ -112,12 +124,20 @@ impl Encode<'_, MySql> for NaiveDate {
IsNull::No
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
fn size_hint(&self) -> usize {
5
}
}
impl<'r> Decode<'r, MySql> for NaiveDate {
fn accepts(ty: &MySqlTypeInfo) -> bool {
*ty == <Self as Type<MySql>>::type_info()
}
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
match value.format() {
MySqlValueFormat::Binary => Ok(decode_date(&value.as_bytes()?[1..])),
@ -150,6 +170,10 @@ impl Encode<'_, MySql> for NaiveDateTime {
IsNull::No
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
fn size_hint(&self) -> usize {
// to save space the packet can be compressed:
match (

View File

@ -29,6 +29,10 @@ impl Encode<'_, MySql> for f32 {
IsNull::No
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
}
impl Encode<'_, MySql> for f64 {
@ -37,6 +41,10 @@ impl Encode<'_, MySql> for f64 {
IsNull::No
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
}
impl Decode<'_, MySql> for f32 {

View File

@ -39,6 +39,10 @@ impl Encode<'_, MySql> for i8 {
IsNull::No
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
}
impl Encode<'_, MySql> for i16 {
@ -47,6 +51,10 @@ impl Encode<'_, MySql> for i16 {
IsNull::No
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
}
impl Encode<'_, MySql> for i32 {
@ -55,6 +63,10 @@ impl Encode<'_, MySql> for i32 {
IsNull::No
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
}
impl Encode<'_, MySql> for i64 {
@ -63,6 +75,10 @@ impl Encode<'_, MySql> for i64 {
IsNull::No
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
}
fn int_accepts(ty: &MySqlTypeInfo) -> bool {

View File

@ -27,6 +27,10 @@ where
<&str as Encode<MySql>>::encode(json_string_value.as_str(), buf)
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
}
impl<'r, T> Decode<'r, MySql> for Json<T>

View File

@ -22,6 +22,10 @@ impl Encode<'_, MySql> for &'_ str {
IsNull::No
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
}
impl<'r> Decode<'r, MySql> for &'r str {
@ -54,6 +58,10 @@ impl Encode<'_, MySql> for String {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
<&str as Encode<MySql>>::encode(&**self, buf)
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
}
impl Decode<'_, MySql> for String {

View File

@ -26,6 +26,10 @@ impl Encode<'_, MySql> for OffsetDateTime {
Encode::<MySql>::encode(&primitive_dt, buf)
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
}
impl<'r> Decode<'r, MySql> for OffsetDateTime {
@ -63,6 +67,10 @@ impl Encode<'_, MySql> for Time {
IsNull::No
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
fn size_hint(&self) -> usize {
if self.nanosecond() == 0 {
// if micro_seconds is 0, length is 8 and micro_seconds is not sent
@ -75,6 +83,10 @@ impl Encode<'_, MySql> for Time {
}
impl<'r> Decode<'r, MySql> for Time {
fn accepts(ty: &MySqlTypeInfo) -> bool {
*ty == <Self as Type<MySql>>::type_info()
}
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
match value.format() {
MySqlValueFormat::Binary => {
@ -128,12 +140,20 @@ impl Encode<'_, MySql> for Date {
IsNull::No
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
fn size_hint(&self) -> usize {
5
}
}
impl<'r> Decode<'r, MySql> for Date {
fn accepts(ty: &MySqlTypeInfo) -> bool {
*ty == <Self as Type<MySql>>::type_info()
}
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
match value.format() {
MySqlValueFormat::Binary => decode_date(&value.as_bytes()?[1..]),
@ -165,6 +185,10 @@ impl Encode<'_, MySql> for PrimitiveDateTime {
IsNull::No
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
fn size_hint(&self) -> usize {
// to save space the packet can be compressed:
match (self.hour(), self.minute(), self.second(), self.nanosecond()) {

View File

@ -47,6 +47,10 @@ impl Encode<'_, MySql> for u8 {
IsNull::No
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
}
impl Encode<'_, MySql> for u16 {
@ -55,6 +59,10 @@ impl Encode<'_, MySql> for u16 {
IsNull::No
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
}
impl Encode<'_, MySql> for u32 {
@ -63,6 +71,10 @@ impl Encode<'_, MySql> for u32 {
IsNull::No
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
}
impl Encode<'_, MySql> for u64 {
@ -71,6 +83,10 @@ impl Encode<'_, MySql> for u64 {
IsNull::No
}
fn produces(&self) -> Option<MySqlTypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
}
fn uint_accepts(ty: &MySqlTypeInfo) -> bool {

View File

@ -5,6 +5,7 @@ use crate::encode::{Encode, IsNull};
use crate::error::Error;
use crate::ext::ustr::UStr;
use crate::postgres::{PgConnection, PgTypeInfo, Postgres};
use crate::types::Type;
#[derive(Default)]
pub struct PgArgumentBuffer {
@ -40,10 +41,11 @@ impl<'q> Arguments<'q> for PgArguments {
fn add<T>(&mut self, value: T)
where
T: Encode<'q, Self::Database>,
T: Encode<'q, Self::Database> + Type<Self::Database>,
{
// remember the type information for this value
self.types.push(value.produces());
self.types
.push(value.produces().unwrap_or_else(T::type_info));
// reserve space to write the prefixed length of the value
let offset = self.buffer.len();

View File

@ -138,7 +138,7 @@ impl PgConnection {
b'P' => Err(err_protocol!("pseudo types are unsupported")),
b'R' => Err(err_protocol!("user-defined range types are unsupported")),
b'R' => self.fetch_range_by_oid(oid, name).await,
b'E' => self.fetch_enum_by_oid(oid, name).await,
@ -209,6 +209,27 @@ ORDER BY attnum
})
}
async fn fetch_range_by_oid(&mut self, oid: u32, name: String) -> Result<PgTypeInfo, Error> {
let _: i32 = query_scalar(
r#"
SELECT 1
FROM pg_catalog.pg_range
WHERE rngtypid = $1
"#,
)
.bind(oid)
.fetch_one(self)
.await?;
let pg_type = PgType::try_from_oid(oid).ok_or_else(|| err_protocol!("Trying to retrieve a DB type that doesn't exist in SQLx"))?;
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Range(PgTypeInfo(pg_type)),
name: name.into(),
oid,
}))))
}
pub(crate) async fn fetch_type_id_by_name(&mut self, name: &str) -> Result<u32, Error> {
if let Some(oid) = self.cache_type_oid.get(name) {
return Ok(*oid);

View File

@ -35,11 +35,15 @@ where
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
self.as_slice().encode_by_ref(buf)
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
}
impl<'q, T> Encode<'q, Postgres> for &'_ [T]
where
T: Encode<'q, Postgres>,
T: Encode<'q, Postgres> + Type<Postgres>,
Self: Type<Postgres>,
{
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
@ -79,6 +83,10 @@ where
IsNull::No
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
}
// TODO: Array decoding in PostgreSQL *could* allow 'r (row) lifetime of elements if we can figure
@ -86,9 +94,13 @@ where
impl<'r, T> Decode<'r, Postgres> for Vec<T>
where
T: for<'a> Decode<'a, Postgres>,
T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
Self: Type<Postgres>,
{
fn accepts(ty: &PgTypeInfo) -> bool {
*ty == <Self as Type<Postgres>>::type_info()
}
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
let element_type_info = T::type_info();
let format = value.format();

View File

@ -157,6 +157,10 @@ impl Encode<'_, Postgres> for BigDecimal {
IsNull::No
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
fn size_hint(&self) -> usize {
// BigDecimal::digits() gives us base-10 digits, so we divide by 4 to get base-10000 digits
// and since this is just a hint we just always round up
@ -165,6 +169,10 @@ impl Encode<'_, Postgres> for BigDecimal {
}
impl Decode<'_, Postgres> for BigDecimal {
fn accepts(ty: &PgTypeInfo) -> bool {
*ty == <Self as Type<Postgres>>::type_info()
}
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
match value.format() {
PgValueFormat::Binary => PgNumeric::decode(value.as_bytes()?)?.try_into(),

View File

@ -28,9 +28,17 @@ impl Encode<'_, Postgres> for bool {
IsNull::No
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
}
impl Decode<'_, Postgres> for bool {
fn accepts(ty: &PgTypeInfo) -> bool {
*ty == <Self as Type<Postgres>>::type_info()
}
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
PgValueFormat::Binary => value.as_bytes()?[0] != 0,

View File

@ -46,15 +46,27 @@ impl Encode<'_, Postgres> for &'_ [u8] {
IsNull::No
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
}
impl Encode<'_, Postgres> for Vec<u8> {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
<&[u8] as Encode<Postgres>>::encode(self, buf)
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
}
impl<'r> Decode<'r, Postgres> for &'r [u8] {
fn accepts(ty: &PgTypeInfo) -> bool {
*ty == <Self as Type<Postgres>>::type_info()
}
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
match value.format() {
PgValueFormat::Binary => value.as_bytes(),
@ -66,6 +78,10 @@ impl<'r> Decode<'r, Postgres> for &'r [u8] {
}
impl Decode<'_, Postgres> for Vec<u8> {
fn accepts(ty: &PgTypeInfo) -> bool {
*ty == <Self as Type<Postgres>>::type_info()
}
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
PgValueFormat::Binary => value.as_bytes()?.to_owned(),

View File

@ -90,12 +90,20 @@ impl Encode<'_, Postgres> for NaiveTime {
Encode::<Postgres>::encode(&us, buf)
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
fn size_hint(&self) -> usize {
mem::size_of::<u64>()
}
}
impl<'r> Decode<'r, Postgres> for NaiveTime {
fn accepts(ty: &PgTypeInfo) -> bool {
*ty == <Self as Type<Postgres>>::type_info()
}
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
PgValueFormat::Binary => {
@ -116,12 +124,20 @@ impl Encode<'_, Postgres> for NaiveDate {
Encode::<Postgres>::encode(&days, buf)
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
fn size_hint(&self) -> usize {
mem::size_of::<i32>()
}
}
impl<'r> Decode<'r, Postgres> for NaiveDate {
fn accepts(ty: &PgTypeInfo) -> bool {
*ty == <Self as Type<Postgres>>::type_info()
}
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
PgValueFormat::Binary => {
@ -146,12 +162,20 @@ impl Encode<'_, Postgres> for NaiveDateTime {
Encode::<Postgres>::encode(&us, buf)
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
fn size_hint(&self) -> usize {
mem::size_of::<i64>()
}
}
impl<'r> Decode<'r, Postgres> for NaiveDateTime {
fn accepts(ty: &PgTypeInfo) -> bool {
*ty == <Self as Type<Postgres>>::type_info()
}
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
PgValueFormat::Binary => {
@ -184,12 +208,20 @@ impl<Tz: TimeZone> Encode<'_, Postgres> for DateTime<Tz> {
Encode::<Postgres>::encode(self.naive_utc(), buf)
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
fn size_hint(&self) -> usize {
mem::size_of::<i64>()
}
}
impl<'r> Decode<'r, Postgres> for DateTime<Local> {
fn accepts(ty: &PgTypeInfo) -> bool {
*ty == <Self as Type<Postgres>>::type_info()
}
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
let naive = <NaiveDateTime as Decode<Postgres>>::decode(value)?;
Ok(Local.from_utc_datetime(&naive))
@ -197,6 +229,10 @@ impl<'r> Decode<'r, Postgres> for DateTime<Local> {
}
impl<'r> Decode<'r, Postgres> for DateTime<Utc> {
fn accepts(ty: &PgTypeInfo) -> bool {
*ty == <Self as Type<Postgres>>::type_info()
}
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
let naive = <NaiveDateTime as Decode<Postgres>>::decode(value)?;
Ok(Utc.from_utc_datetime(&naive))

View File

@ -30,9 +30,17 @@ impl Encode<'_, Postgres> for f32 {
IsNull::No
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
}
impl Decode<'_, Postgres> for f32 {
fn accepts(ty: &PgTypeInfo) -> bool {
*ty == <Self as Type<Postgres>>::type_info()
}
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
PgValueFormat::Binary => BigEndian::read_f32(value.as_bytes()?),
@ -65,9 +73,17 @@ impl Encode<'_, Postgres> for f64 {
IsNull::No
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
}
impl Decode<'_, Postgres> for f64 {
fn accepts(ty: &PgTypeInfo) -> bool {
*ty == <Self as Type<Postgres>>::type_info()
}
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
PgValueFormat::Binary => BigEndian::read_f64(value.as_bytes()?),

View File

@ -74,6 +74,10 @@ impl Encode<'_, Postgres> for IpNetwork {
IpNetwork::V6(_) => 20,
}
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
}
impl Decode<'_, Postgres> for IpNetwork {

View File

@ -43,6 +43,10 @@ where
IsNull::No
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
}
impl<'r, T: 'r> Decode<'r, Postgres> for Json<T>

View File

@ -133,6 +133,7 @@ mod bool;
mod bytes;
mod float;
mod num;
mod ranges;
mod record;
mod str;
mod tuple;
@ -158,4 +159,7 @@ mod json;
#[cfg(feature = "ipnetwork")]
mod ipnetwork;
pub use record::{PgRecordDecoder, PgRecordEncoder};
pub use {
ranges::{pg_range::PgRange, pg_ranges::*},
record::{PgRecordDecoder, PgRecordEncoder},
};

View File

@ -30,9 +30,17 @@ impl Encode<'_, Postgres> for i8 {
IsNull::No
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
}
impl Decode<'_, Postgres> for i8 {
fn accepts(ty: &PgTypeInfo) -> bool {
*ty == <Self as Type<Postgres>>::type_info()
}
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
// note: in the TEXT encoding, a value of "0" here is encoded as an empty string
Ok(value.as_bytes()?.get(0).copied().unwrap_or_default() as i8)
@ -63,9 +71,17 @@ impl Encode<'_, Postgres> for i16 {
IsNull::No
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
}
impl Decode<'_, Postgres> for i16 {
fn accepts(ty: &PgTypeInfo) -> bool {
*ty == <Self as Type<Postgres>>::type_info()
}
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
PgValueFormat::Binary => BigEndian::read_i16(value.as_bytes()?),
@ -98,9 +114,17 @@ impl Encode<'_, Postgres> for u32 {
IsNull::No
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
}
impl Decode<'_, Postgres> for u32 {
fn accepts(ty: &PgTypeInfo) -> bool {
*ty == <Self as Type<Postgres>>::type_info()
}
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
PgValueFormat::Binary => BigEndian::read_u32(value.as_bytes()?),
@ -133,9 +157,17 @@ impl Encode<'_, Postgres> for i32 {
IsNull::No
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
}
impl Decode<'_, Postgres> for i32 {
fn accepts(ty: &PgTypeInfo) -> bool {
*ty == <Self as Type<Postgres>>::type_info()
}
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
PgValueFormat::Binary => BigEndian::read_i32(value.as_bytes()?),
@ -168,9 +200,17 @@ impl Encode<'_, Postgres> for i64 {
IsNull::No
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
}
impl Decode<'_, Postgres> for i64 {
fn accepts(ty: &PgTypeInfo) -> bool {
*ty == <Self as Type<Postgres>>::type_info()
}
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
PgValueFormat::Binary => BigEndian::read_i64(value.as_bytes()?),

View File

@ -0,0 +1,87 @@
pub(crate) mod pg_range;
pub(crate) mod pg_ranges;
use crate::{
decode::Decode,
encode::{Encode, IsNull},
postgres::{types::PgRange, PgArgumentBuffer, PgTypeInfo, PgValueRef, Postgres},
types::Type,
};
use core::{
convert::TryInto,
ops::{Range, RangeFrom, RangeInclusive, RangeTo, RangeToInclusive},
};
macro_rules! impl_range {
($range:ident) => {
impl<'a, T> Decode<'a, Postgres> for $range<T>
where
T: for<'b> Decode<'b, Postgres> + Type<Postgres> + 'a,
{
fn accepts(ty: &PgTypeInfo) -> bool {
<PgRange<T> as Decode<'_, Postgres>>::accepts(ty)
}
fn decode(value: PgValueRef<'a>) -> Result<$range<T>, crate::error::BoxDynError> {
let bounds: PgRange<T> = Decode::<Postgres>::decode(value)?;
let rslt = bounds.try_into()?;
Ok(rslt)
}
}
impl<'a, T> Encode<'a, Postgres> for $range<T>
where
T: Clone + for<'b> Encode<'b, Postgres> + 'a,
{
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
<PgRange<T> as Encode<'_, Postgres>>::encode(self.clone().into(), buf)
}
}
};
}
impl_range!(Range);
impl_range!(RangeFrom);
impl_range!(RangeInclusive);
impl_range!(RangeTo);
impl_range!(RangeToInclusive);
#[test]
fn test_decode_str_bounds() {
use crate::postgres::type_info::PgType;
const EXC1: Bound<i32> = Bound::Excluded(1);
const EXC2: Bound<i32> = Bound::Excluded(2);
const INC1: Bound<i32> = Bound::Included(1);
const INC2: Bound<i32> = Bound::Included(2);
const UNB: Bound<i32> = Bound::Unbounded;
let check = |s: &str, range_cmp: [Bound<i32>; 2]| {
let pg_value = PgValueRef {
type_info: PgTypeInfo(PgType::Int4Range),
format: PgValueFormat::Text,
value: Some(s.as_bytes()),
row: None,
};
let range: PgRange<i32> = Decode::<Postgres>::decode(pg_value).unwrap();
assert_eq!(Into::<[Bound<i32>; 2]>::into(range), range_cmp);
};
check("(,)", [UNB, UNB]);
check("(,]", [UNB, UNB]);
check("(,2)", [UNB, EXC2]);
check("(,2]", [UNB, INC2]);
check("(1,)", [EXC1, UNB]);
check("(1,]", [EXC1, UNB]);
check("(1,2)", [EXC1, EXC2]);
check("(1,2]", [EXC1, INC2]);
check("[,)", [UNB, UNB]);
check("[,]", [UNB, UNB]);
check("[,2)", [UNB, EXC2]);
check("[,2]", [UNB, INC2]);
check("[1,)", [INC1, UNB]);
check("[1,]", [INC1, UNB]);
check("[1,2)", [INC1, EXC2]);
check("[1,2]", [INC1, INC2]);
}

View File

@ -0,0 +1,388 @@
use crate::{
decode::Decode,
encode::{Encode, IsNull},
postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres},
types::Type,
};
use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt};
use core::{
convert::TryFrom,
ops::{Bound, Range, RangeBounds, RangeFrom, RangeInclusive, RangeTo, RangeToInclusive},
};
bitflags::bitflags! {
struct RangeFlags: u8 {
const EMPTY = 0x01;
const LB_INC = 0x02;
const UB_INC = 0x04;
const LB_INF = 0x08;
const UB_INF = 0x10;
const LB_NULL = 0x20;
const UB_NULL = 0x40;
const CONTAIN_EMPTY = 0x80;
}
}
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct PgRange<T> {
pub start: Bound<T>,
pub end: Bound<T>,
}
impl<T> PgRange<T> {
pub fn new(start: Bound<T>, end: Bound<T>) -> Self {
Self {
start,
end
}
}
}
impl<'a, T> Decode<'a, Postgres> for PgRange<T>
where
T: for<'b> Decode<'b, Postgres> + Type<Postgres> + 'a,
{
fn accepts(ty: &PgTypeInfo) -> bool {
[
PgTypeInfo::INT4_RANGE,
PgTypeInfo::NUM_RANGE,
PgTypeInfo::TS_RANGE,
PgTypeInfo::TSTZ_RANGE,
PgTypeInfo::DATE_RANGE,
PgTypeInfo::INT8_RANGE,
]
.contains(ty)
}
fn decode(value: PgValueRef<'a>) -> Result<PgRange<T>, crate::error::BoxDynError> {
match value.format() {
PgValueFormat::Binary => {
decode_binary(value.as_bytes()?, value.format, value.type_info)
}
PgValueFormat::Text => decode_str(value.as_str()?, value.format(), value.type_info),
}
}
}
impl<'a, T> Encode<'a, Postgres> for PgRange<T>
where
T: for<'b> Encode<'b, Postgres> + 'a,
{
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
let mut flags = match self.start {
Bound::Included(_) => RangeFlags::LB_INC,
Bound::Excluded(_) => RangeFlags::empty(),
Bound::Unbounded => RangeFlags::LB_INF,
};
flags |= match self.end {
Bound::Included(_) => RangeFlags::UB_INC,
Bound::Excluded(_) => RangeFlags::empty(),
Bound::Unbounded => RangeFlags::UB_INF,
};
buf.write_u8(flags.bits()).unwrap();
let mut write = |bound: &Bound<T>| -> IsNull {
match bound {
Bound::Included(ref value) | Bound::Excluded(ref value) => {
buf.write_u32::<NetworkEndian>(0).unwrap();
let prev = buf.len();
if let IsNull::Yes = Encode::<Postgres>::encode(value, buf) {
return IsNull::Yes;
}
let len = buf.len() - prev;
buf[prev - 4..prev].copy_from_slice(&(len as u32).to_be_bytes());
}
Bound::Unbounded => {}
}
IsNull::No
};
if let IsNull::Yes = write(&self.start) {
return IsNull::Yes;
}
write(&self.end)
}
}
impl<T> From<[Bound<T>; 2]> for PgRange<T> {
fn from(from: [Bound<T>; 2]) -> Self {
let [start, end] = from;
Self { start, end }
}
}
impl<T> From<(Bound<T>, Bound<T>)> for PgRange<T> {
fn from(from: (Bound<T>, Bound<T>)) -> Self {
Self {
start: from.0,
end: from.1,
}
}
}
impl<T> From<PgRange<T>> for [Bound<T>; 2] {
fn from(from: PgRange<T>) -> Self {
[from.start, from.end]
}
}
impl<T> From<PgRange<T>> for (Bound<T>, Bound<T>) {
fn from(from: PgRange<T>) -> Self {
(from.start, from.end)
}
}
impl<T> From<Range<T>> for PgRange<T> {
fn from(from: Range<T>) -> Self {
Self {
start: Bound::Included(from.start),
end: Bound::Excluded(from.end),
}
}
}
impl<T> From<RangeFrom<T>> for PgRange<T> {
fn from(from: RangeFrom<T>) -> Self {
Self {
start: Bound::Included(from.start),
end: Bound::Unbounded,
}
}
}
impl<T> From<RangeInclusive<T>> for PgRange<T> {
fn from(from: RangeInclusive<T>) -> Self {
let (start, end) = from.into_inner();
Self {
start: Bound::Included(start),
end: Bound::Excluded(end),
}
}
}
impl<T> From<RangeTo<T>> for PgRange<T> {
fn from(from: RangeTo<T>) -> Self {
Self {
start: Bound::Unbounded,
end: Bound::Excluded(from.end),
}
}
}
impl<T> From<RangeToInclusive<T>> for PgRange<T> {
fn from(from: RangeToInclusive<T>) -> Self {
Self {
start: Bound::Unbounded,
end: Bound::Included(from.end),
}
}
}
impl<T> RangeBounds<T> for PgRange<T> {
fn start_bound(&self) -> Bound<&T> {
match &self.start {
Bound::Included(ref start) => Bound::Included(start),
Bound::Excluded(ref start) => Bound::Excluded(start),
Bound::Unbounded => Bound::Unbounded,
}
}
fn end_bound(&self) -> Bound<&T> {
match &self.end {
Bound::Included(ref end) => Bound::Included(end),
Bound::Excluded(ref end) => Bound::Excluded(end),
Bound::Unbounded => Bound::Unbounded,
}
}
}
impl<T> TryFrom<PgRange<T>> for Range<T> {
type Error = crate::error::Error;
fn try_from(from: PgRange<T>) -> crate::error::Result<Self> {
let err_msg = "Invalid data for core::ops::Range";
let start = included(from.start, err_msg)?;
let end = excluded(from.end, err_msg)?;
Ok(start..end)
}
}
impl<T> TryFrom<PgRange<T>> for RangeFrom<T> {
type Error = crate::error::Error;
fn try_from(from: PgRange<T>) -> crate::error::Result<Self> {
let err_msg = "Invalid data for core::ops::RangeFrom";
let start = included(from.start, err_msg)?;
unbounded(from.end, err_msg)?;
Ok(start..)
}
}
impl<T> TryFrom<PgRange<T>> for RangeInclusive<T> {
type Error = crate::error::Error;
fn try_from(from: PgRange<T>) -> crate::error::Result<Self> {
let err_msg = "Invalid data for core::ops::RangeInclusive";
let start = included(from.start, err_msg)?;
let end = included(from.end, err_msg)?;
Ok(start..=end)
}
}
impl<T> TryFrom<PgRange<T>> for RangeTo<T> {
type Error = crate::error::Error;
fn try_from(from: PgRange<T>) -> crate::error::Result<Self> {
let err_msg = "Invalid data for core::ops::RangeTo";
unbounded(from.start, err_msg)?;
let end = excluded(from.end, err_msg)?;
Ok(..end)
}
}
impl<T> TryFrom<PgRange<T>> for RangeToInclusive<T> {
type Error = crate::error::Error;
fn try_from(from: PgRange<T>) -> crate::error::Result<Self> {
let err_msg = "Invalid data for core::ops::RangeToInclusive";
unbounded(from.start, err_msg)?;
let end = included(from.end, err_msg)?;
Ok(..=end)
}
}
fn decode_binary<'r, T>(
mut bytes: &[u8],
format: PgValueFormat,
type_info: PgTypeInfo,
) -> Result<PgRange<T>, crate::error::BoxDynError>
where
T: for<'rec> Decode<'rec, Postgres> + 'r,
{
let flags: RangeFlags = RangeFlags::from_bits_truncate(bytes.read_u8()?);
let mut start_value = Bound::Unbounded;
let mut end_value = Bound::Unbounded;
if flags.contains(RangeFlags::EMPTY) {
return Ok(PgRange {
start: start_value,
end: end_value,
});
}
if !flags.contains(RangeFlags::LB_INF) {
let elem_size = bytes.read_i32::<NetworkEndian>()?;
let (elem_bytes, new_bytes) = bytes.split_at(elem_size as usize);
bytes = new_bytes;
let value = T::decode(PgValueRef {
type_info: type_info.clone(),
format,
value: Some(elem_bytes),
row: None,
})?;
start_value = if flags.contains(RangeFlags::LB_INC) {
Bound::Included(value)
} else {
Bound::Excluded(value)
};
}
if !flags.contains(RangeFlags::UB_INF) {
bytes.read_i32::<NetworkEndian>()?;
let value = T::decode(PgValueRef {
type_info,
format,
value: Some(bytes),
row: None,
})?;
end_value = if flags.contains(RangeFlags::UB_INC) {
Bound::Included(value)
} else {
Bound::Excluded(value)
};
}
Ok(PgRange {
start: start_value,
end: end_value,
})
}
fn decode_str<'r, T>(
s: &str,
format: PgValueFormat,
type_info: PgTypeInfo,
) -> Result<PgRange<T>, crate::error::BoxDynError>
where
T: for<'rec> Decode<'rec, Postgres> + 'r,
{
let err = || crate::error::Error::Decode("Invalid PostgreSQL range string".into());
let value =
|bound: &str, delim, bounds: [&str; 2]| -> Result<Bound<T>, crate::error::BoxDynError> {
if bound.len() == 0 {
return Ok(Bound::Unbounded);
}
let bound_value = T::decode(PgValueRef {
type_info: type_info.clone(),
format,
value: Some(bound.as_bytes()),
row: None,
})?;
if delim == bounds[0] {
Ok(Bound::Excluded(bound_value))
} else if delim == bounds[1] {
Ok(Bound::Included(bound_value))
} else {
Err(Box::new(err()))
}
};
let mut parts = s.split(',');
let start_str = parts.next().ok_or_else(err)?;
let start_value = value(
start_str.get(1..).ok_or_else(err)?,
start_str.get(0..1).ok_or_else(err)?,
["(", "["],
)?;
let end_str = parts.next().ok_or_else(err)?;
let last_char_idx = end_str.len() - 1;
let end_value = value(
end_str.get(..last_char_idx).ok_or_else(err)?,
end_str.get(last_char_idx..).ok_or_else(err)?,
[")", "]"],
)?;
Ok(PgRange {
start: start_value,
end: end_value,
})
}
fn excluded<T>(b: Bound<T>, err_msg: &str) -> crate::error::Result<T> {
if let Bound::Excluded(rslt) = b {
Ok(rslt)
} else {
Err(crate::error::Error::Decode(err_msg.into()))
}
}
fn included<T>(b: Bound<T>, err_msg: &str) -> crate::error::Result<T> {
if let Bound::Included(rslt) = b {
Ok(rslt)
} else {
Err(crate::error::Error::Decode(err_msg.into()))
}
}
fn unbounded<T>(b: Bound<T>, err_msg: &str) -> crate::error::Result<()> {
if matches!(b, Bound::Unbounded) {
Ok(())
} else {
Err(crate::error::Error::Decode(err_msg.into()))
}
}

View File

@ -0,0 +1,84 @@
use crate::{
decode::Decode,
encode::{Encode, IsNull},
postgres::{
types::ranges::pg_range::PgRange, PgArgumentBuffer, PgTypeInfo, PgValueRef, Postgres,
},
types::Type,
};
macro_rules! impl_pg_range {
($range_name:ident, $type_info:expr, $type_info_array:expr, $range_type:ty) => {
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
#[repr(transparent)]
pub struct $range_name(pub PgRange<$range_type>);
impl<'a> Decode<'a, Postgres> for $range_name {
fn accepts(ty: &PgTypeInfo) -> bool {
<PgRange<$range_type> as Decode<'_, Postgres>>::accepts(ty)
}
fn decode(value: PgValueRef<'a>) -> Result<$range_name, crate::error::BoxDynError> {
Ok(Self(Decode::<Postgres>::decode(value)?))
}
}
impl<'a> Encode<'a, Postgres> for $range_name {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
<PgRange<$range_type> as Encode<'_, Postgres>>::encode_by_ref(&self.0, buf)
}
}
impl Type<Postgres> for $range_name {
fn type_info() -> PgTypeInfo {
$type_info
}
}
impl Type<Postgres> for [$range_name] {
fn type_info() -> PgTypeInfo {
$type_info_array
}
}
impl Type<Postgres> for Vec<$range_name> {
fn type_info() -> PgTypeInfo {
$type_info_array
}
}
};
}
impl_pg_range!(
Int4Range,
PgTypeInfo::INT4_RANGE,
PgTypeInfo::INT4_RANGE_ARRAY,
i32
);
#[cfg(feature = "bigdecimal")]
impl_pg_range!(
NumRange,
PgTypeInfo::NUM_RANGE,
PgTypeInfo::NUM_RANGE_ARRAY,
bigdecimal::BigDecimal
);
#[cfg(feature = "chrono")]
impl_pg_range!(
TsRange,
PgTypeInfo::TS_RANGE,
PgTypeInfo::TS_RANGE_ARRAY,
chrono::NaiveDateTime
);
#[cfg(feature = "chrono")]
impl_pg_range!(
DateRange,
PgTypeInfo::DATE_RANGE,
PgTypeInfo::DATE_RANGE_ARRAY,
chrono::NaiveDate
);
impl_pg_range!(
Int8Range,
PgTypeInfo::INT8_RANGE,
PgTypeInfo::INT8_RANGE_ARRAY,
i64
);

View File

@ -7,6 +7,7 @@ use crate::postgres::type_info::PgType;
use crate::postgres::{
PgArgumentBuffer, PgTypeInfo, PgTypeKind, PgValueFormat, PgValueRef, Postgres,
};
use crate::types::Type;
#[doc(hidden)]
pub struct PgRecordEncoder<'a> {
@ -36,7 +37,7 @@ impl<'a> PgRecordEncoder<'a> {
pub fn encode<'q, T>(&mut self, value: T) -> &mut Self
where
'a: 'q,
T: Encode<'q, Postgres>,
T: Encode<'q, Postgres> + Type<Postgres>,
{
let ty = T::type_info();
@ -101,7 +102,7 @@ impl<'r> PgRecordDecoder<'r> {
#[doc(hidden)]
pub fn try_decode<T>(&mut self) -> Result<T, BoxDynError>
where
T: for<'a> Decode<'a, Postgres>,
T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
{
if self.buf.is_empty() {
return Err(format!("no field `{0}` found on record", self.ind).into());

View File

@ -28,12 +28,20 @@ impl Encode<'_, Postgres> for &'_ str {
IsNull::No
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
}
impl Encode<'_, Postgres> for String {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
<&str as Encode<Postgres>>::encode(&**self, buf)
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
}
impl<'r> Decode<'r, Postgres> for &'r str {

View File

@ -90,12 +90,20 @@ impl Encode<'_, Postgres> for Time {
Encode::<Postgres>::encode(&us, buf)
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
fn size_hint(&self) -> usize {
mem::size_of::<u64>()
}
}
impl<'r> Decode<'r, Postgres> for Time {
fn accepts(ty: &PgTypeInfo) -> bool {
*ty == <Self as Type<Postgres>>::type_info()
}
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
PgValueFormat::Binary => {
@ -131,12 +139,20 @@ impl Encode<'_, Postgres> for Date {
Encode::<Postgres>::encode(&days, buf)
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
fn size_hint(&self) -> usize {
mem::size_of::<i32>()
}
}
impl<'r> Decode<'r, Postgres> for Date {
fn accepts(ty: &PgTypeInfo) -> bool {
*ty == <Self as Type<Postgres>>::type_info()
}
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
PgValueFormat::Binary => {
@ -157,12 +173,20 @@ impl Encode<'_, Postgres> for PrimitiveDateTime {
Encode::<Postgres>::encode(&us, buf)
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
fn size_hint(&self) -> usize {
mem::size_of::<i64>()
}
}
impl<'r> Decode<'r, Postgres> for PrimitiveDateTime {
fn accepts(ty: &PgTypeInfo) -> bool {
*ty == <Self as Type<Postgres>>::type_info()
}
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
PgValueFormat::Binary => {
@ -214,12 +238,20 @@ impl Encode<'_, Postgres> for OffsetDateTime {
Encode::<Postgres>::encode(&primitive, buf)
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
fn size_hint(&self) -> usize {
mem::size_of::<i64>()
}
}
impl<'r> Decode<'r, Postgres> for OffsetDateTime {
fn accepts(ty: &PgTypeInfo) -> bool {
*ty == <Self as Type<Postgres>>::type_info()
}
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
Ok(<PrimitiveDateTime as Decode<Postgres>>::decode(value)?.assume_utc())
}

View File

@ -33,6 +33,10 @@ macro_rules! impl_type_for_tuple {
$($T: Type<Postgres>,)*
$($T: for<'a> Decode<'a, Postgres>,)*
{
fn accepts(ty: &PgTypeInfo) -> bool {
*ty == <Self as Type<Postgres>>::type_info()
}
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
#[allow(unused)]
let mut decoder = PgRecordDecoder::new(value)?;

View File

@ -30,9 +30,17 @@ impl Encode<'_, Postgres> for Uuid {
IsNull::No
}
fn produces(&self) -> Option<PgTypeInfo> {
<Self as Type<Postgres>>::type_info().into()
}
}
impl Decode<'_, Postgres> for Uuid {
fn accepts(ty: &PgTypeInfo) -> bool {
*ty == <Self as Type<Postgres>>::type_info()
}
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
match value.format() {
PgValueFormat::Binary => Uuid::from_slice(value.as_bytes()?),

View File

@ -9,6 +9,7 @@ use crate::database::{Database, HasArguments};
use crate::encode::Encode;
use crate::error::Error;
use crate::executor::{Execute, Executor};
use crate::types::Type;
/// Raw SQL query with bind parameters. Returned by [`query`][crate::query::query].
#[must_use = "query must be executed to affect database"]
@ -57,7 +58,7 @@ impl<'q, DB: Database> Query<'q, DB, <DB as HasArguments<'q>>::Arguments> {
///
/// There is no validation that the value is of the type expected by the query. Most SQL
/// flavors will perform type coercion (Postgres will return a database error).
pub fn bind<T: 'q + Encode<'q, DB>>(mut self, value: T) -> Self {
pub fn bind<T: 'q + Encode<'q, DB> + Type<DB>>(mut self, value: T) -> Self {
if let Some(arguments) = &mut self.arguments {
arguments.add(value);
}

View File

@ -11,6 +11,7 @@ use crate::error::Error;
use crate::executor::{Execute, Executor};
use crate::from_row::FromRow;
use crate::query::{query, query_with, Query};
use crate::types::Type;
/// Raw SQL query with bind parameters, mapped to a concrete type using [`FromRow`].
/// Returned from [`query_as`].
@ -40,7 +41,7 @@ impl<'q, DB: Database, O> QueryAs<'q, DB, O, <DB as HasArguments<'q>>::Arguments
/// Bind a value for use with this SQL query.
///
/// See [`Query::bind`](crate::query::Query::bind).
pub fn bind<T: 'q + Encode<'q, DB>>(mut self, value: T) -> Self {
pub fn bind<T: 'q + Encode<'q, DB> + Type<DB>>(mut self, value: T) -> Self {
self.inner = self.inner.bind(value);
self
}

View File

@ -9,6 +9,7 @@ use crate::error::Error;
use crate::executor::{Execute, Executor};
use crate::from_row::FromRow;
use crate::query_as::{query_as, query_as_with, QueryAs};
use crate::types::Type;
/// Raw SQL query with bind parameters, mapped to a concrete type using [`FromRow`] on `(O,)`.
/// Returned from [`query_scalar`].
@ -36,7 +37,7 @@ impl<'q, DB: Database, O> QueryScalar<'q, DB, O, <DB as HasArguments<'q>>::Argum
/// Bind a value for use with this SQL query.
///
/// See [`Query::bind`](crate::query::Query::bind).
pub fn bind<T: 'q + Encode<'q, DB>>(mut self, value: T) -> Self {
pub fn bind<T: 'q + Encode<'q, DB> + Type<DB>>(mut self, value: T) -> Self {
self.inner = self.inner.bind(value);
self
}

View File

@ -3,6 +3,7 @@ use std::fmt::Debug;
use crate::database::{Database, HasValueRef};
use crate::decode::Decode;
use crate::error::{mismatched_types, Error};
use crate::types::Type;
use crate::value::ValueRef;
/// A type that can be used to index into a [`Row`].
@ -89,7 +90,7 @@ pub trait Row: private_row::Sealed + Unpin + Send + Sync + 'static {
fn get<'r, T, I>(&'r self, index: I) -> T
where
I: ColumnIndex<Self>,
T: Decode<'r, Self::Database>,
T: Decode<'r, Self::Database> + Type<Self::Database>,
{
self.try_get::<T, I>(index).unwrap()
}
@ -132,7 +133,7 @@ pub trait Row: private_row::Sealed + Unpin + Send + Sync + 'static {
fn try_get<'r, T, I>(&'r self, index: I) -> Result<T, Error>
where
I: ColumnIndex<Self>,
T: Decode<'r, Self::Database>,
T: Decode<'r, Self::Database> + Type<Self::Database>,
{
let value = self.try_get_raw(&index)?;

View File

@ -17,6 +17,10 @@ impl<'q> Encode<'q, Sqlite> for bool {
IsNull::No
}
fn produces(&self) -> Option<SqliteTypeInfo> {
<Self as Type<Sqlite>>::type_info().into()
}
}
impl<'r> Decode<'r, Sqlite> for bool {

View File

@ -19,6 +19,10 @@ impl<'q> Encode<'q, Sqlite> for &'q [u8] {
IsNull::No
}
fn produces(&self) -> Option<SqliteTypeInfo> {
<Self as Type<Sqlite>>::type_info().into()
}
}
impl<'r> Decode<'r, Sqlite> for &'r [u8] {
@ -49,6 +53,10 @@ impl<'q> Encode<'q, Sqlite> for Vec<u8> {
IsNull::No
}
fn produces(&self) -> Option<SqliteTypeInfo> {
<Self as Type<Sqlite>>::type_info().into()
}
}
impl<'r> Decode<'r, Sqlite> for Vec<u8> {

View File

@ -17,6 +17,10 @@ impl<'q> Encode<'q, Sqlite> for f32 {
IsNull::No
}
fn produces(&self) -> Option<SqliteTypeInfo> {
<Self as Type<Sqlite>>::type_info().into()
}
}
impl<'r> Decode<'r, Sqlite> for f32 {
@ -41,6 +45,10 @@ impl<'q> Encode<'q, Sqlite> for f64 {
IsNull::No
}
fn produces(&self) -> Option<SqliteTypeInfo> {
<Self as Type<Sqlite>>::type_info().into()
}
}
impl<'r> Decode<'r, Sqlite> for f64 {

View File

@ -17,6 +17,10 @@ impl<'q> Encode<'q, Sqlite> for i32 {
IsNull::No
}
fn produces(&self) -> Option<SqliteTypeInfo> {
<Self as Type<Sqlite>>::type_info().into()
}
}
impl<'r> Decode<'r, Sqlite> for i32 {
@ -41,6 +45,10 @@ impl<'q> Encode<'q, Sqlite> for i64 {
IsNull::No
}
fn produces(&self) -> Option<SqliteTypeInfo> {
<Self as Type<Sqlite>>::type_info().into()
}
}
impl<'r> Decode<'r, Sqlite> for i64 {

View File

@ -19,6 +19,10 @@ impl<'q> Encode<'q, Sqlite> for &'q str {
IsNull::No
}
fn produces(&self) -> Option<SqliteTypeInfo> {
<Self as Type<Sqlite>>::type_info().into()
}
}
impl<'r> Decode<'r, Sqlite> for &'r str {
@ -49,6 +53,10 @@ impl<'q> Encode<'q, Sqlite> for String {
IsNull::No
}
fn produces(&self) -> Option<SqliteTypeInfo> {
<Self as Type<Sqlite>>::type_info().into()
}
}
impl<'r> Decode<'r, Sqlite> for String {

View File

@ -45,6 +45,10 @@ where
fn encode_by_ref(&self, buf: &mut <DB as HasArguments<'q>>::ArgumentBuffer) -> IsNull {
<Json<&Self> as Encode<'q, DB>>::encode(Json(self), buf)
}
fn produces(&self) -> Option<DB::TypeInfo> {
<Self as Type<DB>>::type_info().into()
}
}
impl<'r, DB> Decode<'r, DB> for JsonValue

View File

@ -3,6 +3,7 @@ use std::borrow::Cow;
use crate::database::{Database, HasValueRef};
use crate::decode::Decode;
use crate::error::{mismatched_types, Error};
use crate::types::Type;
/// An owned value from the database.
pub trait Value {
@ -30,7 +31,7 @@ pub trait Value {
#[inline]
fn decode<'r, T>(&'r self) -> T
where
T: Decode<'r, Self::Database>,
T: Decode<'r, Self::Database> + Type<Self::Database>,
{
self.try_decode::<T>().unwrap()
}
@ -64,7 +65,7 @@ pub trait Value {
#[inline]
fn try_decode<'r, T>(&'r self) -> Result<T, Error>
where
T: Decode<'r, Self::Database>,
T: Decode<'r, Self::Database> + Type<Self::Database>,
{
if !self.is_null() {
if let Some(actual_ty) = self.type_info() {

View File

@ -71,6 +71,10 @@ fn expand_derive_decode_transparent(
let tts = quote!(
impl #impl_generics sqlx::decode::Decode<'de, DB> for #ident #ty_generics #where_clause {
fn accepts(ty: &DB::TypeInfo) -> bool {
<#ty as sqlx::decode::Decode<'de, DB>>::accepts(ty)
}
fn decode(value: <DB as sqlx::database::HasValueRef<'de>>::ValueRef) -> std::result::Result<Self, sqlx::BoxDynError> {
<#ty as sqlx::decode::Decode<'de, DB>>::decode(value).map(Self)
}
@ -100,6 +104,10 @@ fn expand_derive_decode_weak_enum(
Ok(quote!(
impl<'de, DB: sqlx::Database> sqlx::decode::Decode<'de, DB> for #ident where #repr: sqlx::decode::Decode<'de, DB> {
fn accepts(ty: &MySqlTypeInfo) -> bool {
*ty == Self::type_info()
}
fn decode(value: <DB as sqlx::database::HasValueRef<'de>>::ValueRef) -> std::result::Result<Self, sqlx::BoxDynError> {
let value = <#repr as sqlx::decode::Decode<'de, DB>>::decode(value)?;
@ -140,6 +148,10 @@ fn expand_derive_decode_strong_enum(
Ok(quote!(
impl<'de, DB: sqlx::Database> sqlx::decode::Decode<'de, DB> for #ident where &'de str: sqlx::decode::Decode<'de, DB> {
fn accepts(ty: &MySqlTypeInfo) -> bool {
*ty == Self::type_info()
}
fn decode(value: <DB as sqlx::database::HasValueRef<'de>>::ValueRef) -> std::result::Result<Self, sqlx::BoxDynError> {
let value = <&'de str as sqlx::decode::Decode<'de, DB>>::decode(value)?;
match value {
@ -195,6 +207,10 @@ fn expand_derive_decode_struct(
tts.extend(quote!(
impl #impl_generics sqlx::decode::Decode<'de, sqlx::Postgres> for #ident #ty_generics #where_clause {
fn accepts(ty: &MySqlTypeInfo) -> bool {
*ty == Self::type_info()
}
fn decode(value: <sqlx::Postgres as sqlx::value::HasRawValue<'de>>::RawValue) -> sqlx::Result<Self> {
let mut decoder = sqlx::postgres::types::raw::PgRecordDecoder::new(value)?;

View File

@ -83,7 +83,9 @@ fn expand_derive_encode_transparent(
fn encode_by_ref(&self, buf: &mut <DB as sqlx::database::HasArguments<#lifetime>>::ArgumentBuffer) -> sqlx::encode::IsNull {
sqlx::encode::Encode::encode_by_ref(&self.0, buf)
}
fn produces(&self) -> Option<DB::TypeInfo> {
<#ty as sqlx::encode::Encode<DB>>::produces(&self.0)
}
fn size_hint(&self) -> usize {
sqlx::encode::Encode::size_hint(&self.0)
}
@ -110,11 +112,15 @@ fn expand_derive_encode_weak_enum(
sqlx::encode::Encode::encode_by_ref(&(*self as #repr), buf)
}
fn size_hint(&self) -> usize {
sqlx::encode::Encode::size_hint(&(*self as #repr))
fn produces(&self) -> Option<DB::TypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
fn size_hint(&self) -> usize {
sqlx::encode::Encode::size_hint(&(*self as #repr))
}
}
}
))
))
}
fn expand_derive_encode_strong_enum(
@ -152,6 +158,10 @@ fn expand_derive_encode_strong_enum(
<str as sqlx::encode::Encode<'q, DB>>::encode_by_ref(val, buf)
}
fn produces(&self) -> Option<DB::TypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
fn size_hint(&self) -> usize {
let val = match self {
#(#value_arms)*
@ -220,6 +230,10 @@ fn expand_derive_encode_struct(
encoder.finish()
}
fn produces(&self) -> Option<DB::TypeInfo> {
<Self as Type<MySql>>::type_info().into()
}
fn size_hint(&self) -> usize {
#column_count * (4 + 4) // oid (int) and length (int) for each column
+ #(#sizes)+* // sum of the size hints for each column

View File

@ -333,3 +333,36 @@ test_type!(decimal<sqlx::types::BigDecimal>(Postgres,
"12.34::numeric" == "12.34".parse::<sqlx::types::BigDecimal>().unwrap(),
"12345.6789::numeric" == "12345.6789".parse::<sqlx::types::BigDecimal>().unwrap(),
));
mod ranges {
use super::*;
use core::ops::Bound;
use sqlx::postgres::types::{Int4Range, PgRange};
const EXC2: Bound<i32> = Bound::Excluded(2);
const EXC3: Bound<i32> = Bound::Excluded(3);
const INC1: Bound<i32> = Bound::Included(1);
const INC2: Bound<i32> = Bound::Included(2);
const UNB: Bound<i32> = Bound::Unbounded;
// int4range display is hard-coded into [l, u)
test_type!(int4range<PgRange<i32>>(Postgres,
"'(,)'::int4range" == Int4Range(PgRange::new([UNB, UNB])),
"'(,]'::int4range" == Int4Range(PgRange::new([UNB, UNB])),
"'(,2)'::int4range" == Int4Range(PgRange::new([UNB, EXC2])),
"'(,2]'::int4range" == Int4Range(PgRange::new([UNB, EXC3])),
"'(1,)'::int4range" == Int4Range(PgRange::new([INC2, UNB])),
"'(1,]'::int4range" == Int4Range(PgRange::new([INC2, UNB])),
"'(1,2]'::int4range" == Int4Range(PgRange::new([INC2, EXC3])),
"'[,)'::int4range" == Int4Range(PgRange::new([UNB, UNB])),
"'[,]'::int4range" == Int4Range(PgRange::new([UNB, UNB])),
"'[,2)'::int4range" == Int4Range(PgRange::new([UNB, EXC2])),
"'[,2]'::int4range" == Int4Range(PgRange::new([UNB, EXC3])),
"'[1,)'::int4range" == Int4Range(PgRange::new([INC1, UNB])),
"'[1,]'::int4range" == Int4Range(PgRange::new([INC1, UNB])),
"'[1,2)'::int4range" == Int4Range(PgRange::new([INC1, EXC2])),
"'[1,2]'::int4range" == Int4Range(PgRange::new([INC1, EXC3])),
));
}