diff --git a/sqlx-core/src/decode.rs b/sqlx-core/src/decode.rs index dd2f3606..82e14d0d 100644 --- a/sqlx-core/src/decode.rs +++ b/sqlx-core/src/decode.rs @@ -1,29 +1,11 @@ //! Provides [`Decode`](trait.Decode.html) for decoding values from the database. -use std::result::Result as StdResult; - use crate::database::{Database, HasValueRef}; use crate::error::BoxDynError; -use crate::types::Type; use crate::value::ValueRef; -/// A specialized result type representing the result of decoding a value from the database. -pub type Result = StdResult; - /// A type that can be decoded from the database. /// -/// ## Derivable -/// -/// This trait can be derived to provide user-defined types where supported by -/// the database driver. -/// -/// ```rust,ignore -/// // `UserId` can now be decoded from the database where -/// // an `i64` was expected. -/// #[derive(Decode)] -/// struct UserId(i64); -/// ``` -/// /// ## How can I implement `Decode`? /// /// A manual implementation of `Decode` can be useful when adding support for @@ -58,14 +40,6 @@ pub type Result = StdResult; /// // are supported by the database /// &'r str: Decode<'r, DB> /// { -/// fn accepts(ty: &DB::TypeInfo) -> bool { -/// // accepts is intended to provide runtime type checking and assert that our decode -/// // function can handle the incoming value from the database -/// -/// // as we are delegating to String -/// <&str as Decode>::accepts(ty) -/// } -/// /// fn decode( /// value: >::ValueRef, /// ) -> Result> { @@ -84,12 +58,8 @@ pub type Result = StdResult; /// } /// ``` pub trait Decode<'r, DB: Database>: Sized { - /// Determines if a value of this type can be created from a value with the - /// given type information. - fn accepts(ty: &DB::TypeInfo) -> bool; - /// Decode a new value of this type using a raw value from the database. - fn decode(value: >::ValueRef) -> Result; + fn decode(value: >::ValueRef) -> Result; } // implement `Decode` for Option for all SQL types @@ -98,11 +68,7 @@ where DB: Database, T: Decode<'r, DB>, { - fn accepts(ty: &DB::TypeInfo) -> bool { - T::accepts(ty) - } - - fn decode(value: >::ValueRef) -> Result { + fn decode(value: >::ValueRef) -> Result { if value.is_null() { Ok(None) } else { @@ -110,10 +76,3 @@ where } } } - -// default implementation of `accepts` -// this can be trivially removed once min_specialization is stable -#[allow(dead_code)] -pub(crate) fn accepts>(ty: &DB::TypeInfo) -> bool { - *ty == T::type_info() -} diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 7e436eba..12959e6e 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -8,6 +8,8 @@ use std::io; use std::result::Result as StdResult; use crate::database::Database; +use crate::type_info::TypeInfo; +use crate::types::Type; /// A specialized `Result` type for SQLx. pub type Result = StdResult; @@ -119,12 +121,12 @@ impl Error { } } -pub(crate) fn mismatched_types(expected: &DB::TypeInfo) -> BoxDynError { - let ty_name = type_name::(); - +pub(crate) fn mismatched_types>(ty: &DB::TypeInfo) -> BoxDynError { format!( - "mismatched types; Rust type `{}` is not compatible with SQL type `{}`", - ty_name, expected + "mismatched types; Rust type `{}` (as SQL type `{}`) is not compatible with SQL type `{}`", + type_name::(), + T::type_info().name(), + ty.name() ) .into() } diff --git a/sqlx-core/src/ext/ustr.rs b/sqlx-core/src/ext/ustr.rs index 29312fa2..fe677a20 100644 --- a/sqlx-core/src/ext/ustr.rs +++ b/sqlx-core/src/ext/ustr.rs @@ -8,7 +8,7 @@ use std::sync::Arc; // a micro-string is either a reference-counted string or a static string // this guarantees these are cheap to clone everywhere #[derive(Debug, Clone, Eq)] -pub(crate) enum UStr { +pub enum UStr { Static(&'static str), Shared(Arc), } diff --git a/sqlx-core/src/from_row.rs b/sqlx-core/src/from_row.rs index 3de02c86..52575c57 100644 --- a/sqlx-core/src/from_row.rs +++ b/sqlx-core/src/from_row.rs @@ -5,7 +5,7 @@ use crate::row::Row; /// /// In order to use [`query_as`] the output type must implement `FromRow`. /// -/// # Deriving +/// ## Derivable /// /// This trait can be automatically derived by SQLx for any struct. The generated implementation /// will consist of a sequence of calls to [`Row::try_get`] using the name from each diff --git a/sqlx-core/src/mssql/protocol/type_info.rs b/sqlx-core/src/mssql/protocol/type_info.rs index 920d33c7..ef27e494 100644 --- a/sqlx-core/src/mssql/protocol/type_info.rs +++ b/sqlx-core/src/mssql/protocol/type_info.rs @@ -482,6 +482,43 @@ impl TypeInfo { buf[offset..(offset + 4)].copy_from_slice(&size.to_le_bytes()); } + pub(crate) fn name(&self) -> &'static str { + match self.ty { + DataType::Null => "NULL", + DataType::TinyInt => "TINYINT", + DataType::SmallInt => "SMALLINT", + DataType::Int => "INT", + DataType::BigInt => "BIGINT", + DataType::Real => "REAL", + DataType::Float => "FLOAT", + + DataType::IntN => match self.size { + 1 => "TINYINT", + 2 => "SMALLINT", + 4 => "INT", + 8 => "BIGINT", + + _ => unreachable!("invalid size {} for int"), + }, + + DataType::FloatN => match self.size { + 4 => "REAL", + 8 => "FLOAT", + + _ => unreachable!("invalid size {} for float"), + }, + + DataType::VarChar => "VARCHAR", + DataType::NVarChar => "NVARCHAR", + DataType::BigVarChar => "BIGVARCHAR", + DataType::Char => "CHAR", + DataType::BigChar => "BIGCHAR", + DataType::NChar => "NCHAR", + + _ => unimplemented!("name: unsupported data type {:?}", self.ty), + } + } + pub(crate) fn fmt(&self, s: &mut String) { match self.ty { DataType::Null => s.push_str("nvarchar(1)"), diff --git a/sqlx-core/src/mssql/type_info.rs b/sqlx-core/src/mssql/type_info.rs index 6389b4ff..d39b134d 100644 --- a/sqlx-core/src/mssql/type_info.rs +++ b/sqlx-core/src/mssql/type_info.rs @@ -7,13 +7,14 @@ use crate::type_info::TypeInfo; #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] pub struct MssqlTypeInfo(pub(crate) ProtocolTypeInfo); -impl TypeInfo for MssqlTypeInfo {} +impl TypeInfo for MssqlTypeInfo { + fn name(&self) -> &str { + self.0.name() + } +} impl Display for MssqlTypeInfo { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - let mut buf = String::new(); - self.0.fmt(&mut buf); - - f.pad(&*buf) + f.pad(self.name()) } } diff --git a/sqlx-core/src/mssql/types/float.rs b/sqlx-core/src/mssql/types/float.rs index 97f52fca..c4e2c337 100644 --- a/sqlx-core/src/mssql/types/float.rs +++ b/sqlx-core/src/mssql/types/float.rs @@ -11,6 +11,10 @@ impl Type for f32 { fn type_info() -> MssqlTypeInfo { MssqlTypeInfo(TypeInfo::new(DataType::FloatN, 4)) } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!(ty.0.ty, DataType::Real | DataType::FloatN) && ty.0.size == 4 + } } impl Encode<'_, Mssql> for f32 { @@ -22,10 +26,6 @@ impl Encode<'_, Mssql> for f32 { } impl Decode<'_, Mssql> for f32 { - fn accepts(ty: &MssqlTypeInfo) -> bool { - matches!(ty.0.ty, DataType::Real | DataType::FloatN) && ty.0.size == 4 - } - fn decode(value: MssqlValueRef<'_>) -> Result { Ok(LittleEndian::read_f32(value.as_bytes()?)) } @@ -35,6 +35,10 @@ impl Type for f64 { fn type_info() -> MssqlTypeInfo { MssqlTypeInfo(TypeInfo::new(DataType::FloatN, 8)) } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!(ty.0.ty, DataType::Float | DataType::FloatN) && ty.0.size == 8 + } } impl Encode<'_, Mssql> for f64 { @@ -46,10 +50,6 @@ impl Encode<'_, Mssql> for f64 { } impl Decode<'_, Mssql> for f64 { - fn accepts(ty: &MssqlTypeInfo) -> bool { - matches!(ty.0.ty, DataType::Float | DataType::FloatN) && ty.0.size == 8 - } - fn decode(value: MssqlValueRef<'_>) -> Result { Ok(LittleEndian::read_f64(value.as_bytes()?)) } diff --git a/sqlx-core/src/mssql/types/int.rs b/sqlx-core/src/mssql/types/int.rs index 3a0c5002..1734e127 100644 --- a/sqlx-core/src/mssql/types/int.rs +++ b/sqlx-core/src/mssql/types/int.rs @@ -11,6 +11,10 @@ impl Type for i8 { fn type_info() -> MssqlTypeInfo { MssqlTypeInfo(TypeInfo::new(DataType::IntN, 1)) } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!(ty.0.ty, DataType::TinyInt | DataType::IntN) && ty.0.size == 1 + } } impl Encode<'_, Mssql> for i8 { @@ -22,10 +26,6 @@ impl Encode<'_, Mssql> for i8 { } impl Decode<'_, Mssql> for i8 { - fn accepts(ty: &MssqlTypeInfo) -> bool { - matches!(ty.0.ty, DataType::TinyInt | DataType::IntN) && ty.0.size == 1 - } - fn decode(value: MssqlValueRef<'_>) -> Result { Ok(value.as_bytes()?[0] as i8) } @@ -35,6 +35,10 @@ impl Type for i16 { fn type_info() -> MssqlTypeInfo { MssqlTypeInfo(TypeInfo::new(DataType::IntN, 2)) } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!(ty.0.ty, DataType::SmallInt | DataType::IntN) && ty.0.size == 2 + } } impl Encode<'_, Mssql> for i16 { @@ -46,10 +50,6 @@ impl Encode<'_, Mssql> for i16 { } impl Decode<'_, Mssql> for i16 { - fn accepts(ty: &MssqlTypeInfo) -> bool { - matches!(ty.0.ty, DataType::SmallInt | DataType::IntN) && ty.0.size == 2 - } - fn decode(value: MssqlValueRef<'_>) -> Result { Ok(LittleEndian::read_i16(value.as_bytes()?)) } @@ -59,6 +59,10 @@ impl Type for i32 { fn type_info() -> MssqlTypeInfo { MssqlTypeInfo(TypeInfo::new(DataType::IntN, 4)) } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!(ty.0.ty, DataType::Int | DataType::IntN) && ty.0.size == 4 + } } impl Encode<'_, Mssql> for i32 { @@ -70,10 +74,6 @@ impl Encode<'_, Mssql> for i32 { } impl Decode<'_, Mssql> for i32 { - fn accepts(ty: &MssqlTypeInfo) -> bool { - matches!(ty.0.ty, DataType::Int | DataType::IntN) && ty.0.size == 4 - } - fn decode(value: MssqlValueRef<'_>) -> Result { Ok(LittleEndian::read_i32(value.as_bytes()?)) } @@ -83,6 +83,10 @@ impl Type for i64 { fn type_info() -> MssqlTypeInfo { MssqlTypeInfo(TypeInfo::new(DataType::IntN, 8)) } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!(ty.0.ty, DataType::BigInt | DataType::IntN) && ty.0.size == 8 + } } impl Encode<'_, Mssql> for i64 { @@ -94,10 +98,6 @@ impl Encode<'_, Mssql> for i64 { } impl Decode<'_, Mssql> for i64 { - fn accepts(ty: &MssqlTypeInfo) -> bool { - matches!(ty.0.ty, DataType::BigInt | DataType::IntN) && ty.0.size == 8 - } - fn decode(value: MssqlValueRef<'_>) -> Result { Ok(LittleEndian::read_i64(value.as_bytes()?)) } diff --git a/sqlx-core/src/mssql/types/mod.rs b/sqlx-core/src/mssql/types/mod.rs index 86505a85..148eaeae 100644 --- a/sqlx-core/src/mssql/types/mod.rs +++ b/sqlx-core/src/mssql/types/mod.rs @@ -8,15 +8,6 @@ mod int; mod str; impl<'q, T: 'q + Encode<'q, Mssql>> Encode<'q, Mssql> for Option { - fn produces(&self) -> Option { - if let Some(v) = self { - v.produces() - } else { - // MSSQL requires a special NULL type ID - Some(MssqlTypeInfo(TypeInfo::new(DataType::Null, 0))) - } - } - fn encode(self, buf: &mut Vec) -> IsNull { if let Some(v) = self { v.encode(buf) @@ -33,6 +24,15 @@ impl<'q, T: 'q + Encode<'q, Mssql>> Encode<'q, Mssql> for Option { } } + fn produces(&self) -> Option { + if let Some(v) = self { + v.produces() + } else { + // MSSQL requires a special NULL type ID + Some(MssqlTypeInfo(TypeInfo::new(DataType::Null, 0))) + } + } + fn size_hint(&self) -> usize { self.as_ref().map_or(0, Encode::size_hint) } diff --git a/sqlx-core/src/mssql/types/str.rs b/sqlx-core/src/mssql/types/str.rs index 2089042d..4902d783 100644 --- a/sqlx-core/src/mssql/types/str.rs +++ b/sqlx-core/src/mssql/types/str.rs @@ -10,12 +10,28 @@ impl Type for str { fn type_info() -> MssqlTypeInfo { MssqlTypeInfo(TypeInfo::new(DataType::NVarChar, 0)) } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!( + ty.0.ty, + DataType::NVarChar + | DataType::NChar + | DataType::BigVarChar + | DataType::VarChar + | DataType::BigChar + | DataType::Char + ) + } } impl Type for String { fn type_info() -> MssqlTypeInfo { >::type_info() } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + >::compatible(ty) + } } impl Encode<'_, Mssql> for &'_ str { @@ -55,18 +71,6 @@ impl Encode<'_, Mssql> for String { } impl Decode<'_, Mssql> for String { - fn accepts(ty: &MssqlTypeInfo) -> bool { - matches!( - ty.0.ty, - DataType::NVarChar - | DataType::NChar - | DataType::BigVarChar - | DataType::VarChar - | DataType::BigChar - | DataType::Char - ) - } - fn decode(value: MssqlValueRef<'_>) -> Result { Ok(value .type_info diff --git a/sqlx-core/src/mysql/protocol/text/column.rs b/sqlx-core/src/mysql/protocol/text/column.rs index d7d42187..9d7dbb60 100644 --- a/sqlx-core/src/mysql/protocol/text/column.rs +++ b/sqlx-core/src/mysql/protocol/text/column.rs @@ -159,18 +159,25 @@ impl Decode<'_, Capabilities> for ColumnDefinition { } impl ColumnType { - pub(crate) fn name(self, char_set: u16) -> &'static str { + pub(crate) fn name(self, char_set: u16, flags: ColumnFlags) -> &'static str { let is_binary = char_set == 63; + let is_unsigned = flags.contains(ColumnFlags::UNSIGNED); + match self { + ColumnType::Tiny if is_unsigned => "TINYINT UNSIGNED", + ColumnType::Short if is_unsigned => "SMALLINT UNSIGNED", + ColumnType::Long if is_unsigned => "INT UNSIGNED", + ColumnType::Int24 if is_unsigned => "MEDIUMINT UNSIGNED", + ColumnType::LongLong if is_unsigned => "BIGINT UNSIGNED", ColumnType::Tiny => "TINYINT", ColumnType::Short => "SMALLINT", ColumnType::Long => "INT", + ColumnType::Int24 => "MEDIUMINT", + ColumnType::LongLong => "BIGINT", ColumnType::Float => "FLOAT", ColumnType::Double => "DOUBLE", ColumnType::Null => "NULL", ColumnType::Timestamp => "TIMESTAMP", - ColumnType::LongLong => "BIGINT", - ColumnType::Int24 => "MEDIUMINT", ColumnType::Date => "DATE", ColumnType::Time => "TIME", ColumnType::Datetime => "DATETIME", diff --git a/sqlx-core/src/mysql/type_info.rs b/sqlx-core/src/mysql/type_info.rs index 92e87fc5..f9ac50d3 100644 --- a/sqlx-core/src/mysql/type_info.rs +++ b/sqlx-core/src/mysql/type_info.rs @@ -59,21 +59,15 @@ impl MySqlTypeInfo { impl Display for MySqlTypeInfo { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.write_str(self.r#type.name(self.char_set))?; - - // NOTE: MariaDB flags timestamp columns as UNSIGNED but the type name - // does not have that suffix - if self.flags.contains(ColumnFlags::UNSIGNED) - && !self.flags.contains(ColumnFlags::TIMESTAMP) - { - f.write_str(" UNSIGNED")?; - } - - Ok(()) + f.pad(self.name()) } } -impl TypeInfo for MySqlTypeInfo {} +impl TypeInfo for MySqlTypeInfo { + fn name(&self) -> &str { + self.r#type.name(self.char_set, self.flags) + } +} impl PartialEq for MySqlTypeInfo { fn eq(&self, other: &MySqlTypeInfo) -> bool { diff --git a/sqlx-core/src/mysql/types/bigdecimal.rs b/sqlx-core/src/mysql/types/bigdecimal.rs index a768a43d..088e9cbf 100644 --- a/sqlx-core/src/mysql/types/bigdecimal.rs +++ b/sqlx-core/src/mysql/types/bigdecimal.rs @@ -1,6 +1,6 @@ use bigdecimal::BigDecimal; -use crate::decode::{accepts, Decode}; +use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::mysql::io::MySqlBufMutExt; @@ -23,10 +23,6 @@ impl Encode<'_, MySql> for BigDecimal { } impl Decode<'_, MySql> for BigDecimal { - fn accepts(ty: &MySqlTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: MySqlValueRef<'_>) -> Result { Ok(value.as_str()?.parse()?) } diff --git a/sqlx-core/src/mysql/types/bool.rs b/sqlx-core/src/mysql/types/bool.rs index d47a121a..c9e061e1 100644 --- a/sqlx-core/src/mysql/types/bool.rs +++ b/sqlx-core/src/mysql/types/bool.rs @@ -9,6 +9,10 @@ impl Type for bool { // MySQL has no actual `BOOLEAN` type, the type is an alias of `TINYINT(1)` >::type_info() } + + fn compatible(ty: &MySqlTypeInfo) -> bool { + >::compatible(ty) + } } impl Encode<'_, MySql> for bool { @@ -18,10 +22,6 @@ impl Encode<'_, MySql> for bool { } impl Decode<'_, MySql> for bool { - fn accepts(ty: &MySqlTypeInfo) -> bool { - >::accepts(ty) - } - fn decode(value: MySqlValueRef<'_>) -> Result { Ok(>::decode(value)? != 0) } diff --git a/sqlx-core/src/mysql/types/bytes.rs b/sqlx-core/src/mysql/types/bytes.rs index 0988f466..1426425a 100644 --- a/sqlx-core/src/mysql/types/bytes.rs +++ b/sqlx-core/src/mysql/types/bytes.rs @@ -10,18 +10,8 @@ impl Type for [u8] { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Blob) } -} -impl Encode<'_, MySql> for &'_ [u8] { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { - buf.put_bytes_lenenc(self); - - IsNull::No - } -} - -impl<'r> Decode<'r, MySql> for &'r [u8] { - fn accepts(ty: &MySqlTypeInfo) -> bool { + fn compatible(ty: &MySqlTypeInfo) -> bool { matches!( ty.r#type, ColumnType::VarChar @@ -34,7 +24,17 @@ impl<'r> Decode<'r, MySql> for &'r [u8] { | ColumnType::Enum ) } +} +impl Encode<'_, MySql> for &'_ [u8] { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + buf.put_bytes_lenenc(self); + + IsNull::No + } +} + +impl<'r> Decode<'r, MySql> for &'r [u8] { fn decode(value: MySqlValueRef<'r>) -> Result { value.as_bytes() } @@ -44,6 +44,10 @@ impl Type for Vec { fn type_info() -> MySqlTypeInfo { <[u8] as Type>::type_info() } + + fn compatible(ty: &MySqlTypeInfo) -> bool { + <&[u8] as Type>::compatible(ty) + } } impl Encode<'_, MySql> for Vec { @@ -53,10 +57,6 @@ impl Encode<'_, MySql> for Vec { } impl Decode<'_, MySql> for Vec { - fn accepts(ty: &MySqlTypeInfo) -> bool { - <&[u8] as Decode>::accepts(ty) - } - fn decode(value: MySqlValueRef<'_>) -> Result { <&[u8] as Decode>::decode(value).map(ToOwned::to_owned) } diff --git a/sqlx-core/src/mysql/types/chrono.rs b/sqlx-core/src/mysql/types/chrono.rs index 9461524f..da84c4ea 100644 --- a/sqlx-core/src/mysql/types/chrono.rs +++ b/sqlx-core/src/mysql/types/chrono.rs @@ -3,7 +3,7 @@ use std::convert::TryFrom; use bytes::Buf; use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc}; -use crate::decode::{accepts, Decode}; +use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::mysql::protocol::text::ColumnType; @@ -15,6 +15,10 @@ impl Type for DateTime { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Timestamp) } + + fn compatible(ty: &MySqlTypeInfo) -> bool { + matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp) + } } impl Encode<'_, MySql> for DateTime { @@ -24,10 +28,6 @@ impl Encode<'_, MySql> for DateTime { } impl<'r> Decode<'r, MySql> for DateTime { - fn accepts(ty: &MySqlTypeInfo) -> bool { - matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp) - } - fn decode(value: MySqlValueRef<'r>) -> Result { let naive: NaiveDateTime = Decode::::decode(value)?; @@ -70,10 +70,6 @@ impl Encode<'_, MySql> for NaiveTime { } impl<'r> Decode<'r, MySql> for NaiveTime { - fn accepts(ty: &MySqlTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: MySqlValueRef<'r>) -> Result { match value.format() { MySqlValueFormat::Binary => { @@ -122,10 +118,6 @@ impl Encode<'_, MySql> for NaiveDate { } impl<'r> Decode<'r, MySql> for NaiveDate { - fn accepts(ty: &MySqlTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: MySqlValueRef<'r>) -> Result { match value.format() { MySqlValueFormat::Binary => Ok(decode_date(&value.as_bytes()?[1..])), @@ -181,10 +173,6 @@ impl Encode<'_, MySql> for NaiveDateTime { } impl<'r> Decode<'r, MySql> for NaiveDateTime { - fn accepts(ty: &MySqlTypeInfo) -> bool { - matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp) - } - fn decode(value: MySqlValueRef<'r>) -> Result { match value.format() { MySqlValueFormat::Binary => { diff --git a/sqlx-core/src/mysql/types/float.rs b/sqlx-core/src/mysql/types/float.rs index 990bf20c..07a76895 100644 --- a/sqlx-core/src/mysql/types/float.rs +++ b/sqlx-core/src/mysql/types/float.rs @@ -7,7 +7,7 @@ use crate::mysql::protocol::text::ColumnType; use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueFormat, MySqlValueRef}; use crate::types::Type; -fn real_accepts(ty: &MySqlTypeInfo) -> bool { +fn real_compatible(ty: &MySqlTypeInfo) -> bool { matches!(ty.r#type, ColumnType::Float | ColumnType::Double) } @@ -15,12 +15,20 @@ impl Type for f32 { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Float) } + + fn compatible(ty: &MySqlTypeInfo) -> bool { + real_compatible(ty) + } } impl Type for f64 { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Double) } + + fn compatible(ty: &MySqlTypeInfo) -> bool { + real_compatible(ty) + } } impl Encode<'_, MySql> for f32 { @@ -40,10 +48,6 @@ impl Encode<'_, MySql> for f64 { } impl Decode<'_, MySql> for f32 { - fn accepts(ty: &MySqlTypeInfo) -> bool { - real_accepts(ty) - } - fn decode(value: MySqlValueRef<'_>) -> Result { Ok(match value.format() { MySqlValueFormat::Binary => { @@ -64,10 +68,6 @@ impl Decode<'_, MySql> for f32 { } impl Decode<'_, MySql> for f64 { - fn accepts(ty: &MySqlTypeInfo) -> bool { - real_accepts(ty) - } - fn decode(value: MySqlValueRef<'_>) -> Result { Ok(match value.format() { MySqlValueFormat::Binary => LittleEndian::read_f64(value.as_bytes()?), diff --git a/sqlx-core/src/mysql/types/int.rs b/sqlx-core/src/mysql/types/int.rs index a09c4252..5884d30f 100644 --- a/sqlx-core/src/mysql/types/int.rs +++ b/sqlx-core/src/mysql/types/int.rs @@ -9,28 +9,55 @@ use crate::mysql::protocol::text::{ColumnFlags, ColumnType}; use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueFormat, MySqlValueRef}; use crate::types::Type; +fn int_compatible(ty: &MySqlTypeInfo) -> bool { + matches!( + ty.r#type, + ColumnType::Tiny + | ColumnType::Short + | ColumnType::Long + | ColumnType::Int24 + | ColumnType::LongLong + ) && !ty.flags.contains(ColumnFlags::UNSIGNED) +} + impl Type for i8 { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Tiny) } + + fn compatible(ty: &MySqlTypeInfo) -> bool { + int_compatible(ty) + } } impl Type for i16 { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Short) } + + fn compatible(ty: &MySqlTypeInfo) -> bool { + int_compatible(ty) + } } impl Type for i32 { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Long) } + + fn compatible(ty: &MySqlTypeInfo) -> bool { + int_compatible(ty) + } } impl Type for i64 { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::LongLong) } + + fn compatible(ty: &MySqlTypeInfo) -> bool { + int_compatible(ty) + } } impl Encode<'_, MySql> for i8 { @@ -65,17 +92,6 @@ impl Encode<'_, MySql> for i64 { } } -fn int_accepts(ty: &MySqlTypeInfo) -> bool { - matches!( - ty.r#type, - ColumnType::Tiny - | ColumnType::Short - | ColumnType::Long - | ColumnType::Int24 - | ColumnType::LongLong - ) && !ty.flags.contains(ColumnFlags::UNSIGNED) -} - fn int_decode(value: MySqlValueRef<'_>) -> Result { Ok(match value.format() { MySqlValueFormat::Text => value.as_str()?.parse()?, @@ -87,40 +103,24 @@ fn int_decode(value: MySqlValueRef<'_>) -> Result { } impl Decode<'_, MySql> for i8 { - fn accepts(ty: &MySqlTypeInfo) -> bool { - int_accepts(ty) - } - fn decode(value: MySqlValueRef<'_>) -> Result { int_decode(value)?.try_into().map_err(Into::into) } } impl Decode<'_, MySql> for i16 { - fn accepts(ty: &MySqlTypeInfo) -> bool { - int_accepts(ty) - } - fn decode(value: MySqlValueRef<'_>) -> Result { int_decode(value)?.try_into().map_err(Into::into) } } impl Decode<'_, MySql> for i32 { - fn accepts(ty: &MySqlTypeInfo) -> bool { - int_accepts(ty) - } - fn decode(value: MySqlValueRef<'_>) -> Result { int_decode(value)?.try_into().map_err(Into::into) } } impl Decode<'_, MySql> for i64 { - fn accepts(ty: &MySqlTypeInfo) -> bool { - int_accepts(ty) - } - fn decode(value: MySqlValueRef<'_>) -> Result { int_decode(value)?.try_into().map_err(Into::into) } diff --git a/sqlx-core/src/mysql/types/json.rs b/sqlx-core/src/mysql/types/json.rs index d21503b5..2ec32057 100644 --- a/sqlx-core/src/mysql/types/json.rs +++ b/sqlx-core/src/mysql/types/json.rs @@ -15,6 +15,12 @@ impl Type for Json { // and has nothing to do with the native storage ability of MySQL v8+ MySqlTypeInfo::binary(ColumnType::String) } + + fn compatible(ty: &MySqlTypeInfo) -> bool { + ty.r#type == ColumnType::Json + || <&str as Type>::compatible(ty) + || <&[u8] as Type>::compatible(ty) + } } impl Encode<'_, MySql> for Json @@ -33,12 +39,6 @@ impl<'r, T> Decode<'r, MySql> for Json where T: 'r + DeserializeOwned, { - fn accepts(ty: &MySqlTypeInfo) -> bool { - ty.r#type == ColumnType::Json - || <&str as Decode>::accepts(ty) - || <&[u8] as Decode>::accepts(ty) - } - fn decode(value: MySqlValueRef<'r>) -> Result { let string_value = <&str as Decode>::decode(value)?; diff --git a/sqlx-core/src/mysql/types/str.rs b/sqlx-core/src/mysql/types/str.rs index d3ce31cb..c88c347a 100644 --- a/sqlx-core/src/mysql/types/str.rs +++ b/sqlx-core/src/mysql/types/str.rs @@ -14,18 +14,8 @@ impl Type for str { flags: ColumnFlags::empty(), } } -} -impl Encode<'_, MySql> for &'_ str { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { - buf.put_str_lenenc(self); - - IsNull::No - } -} - -impl<'r> Decode<'r, MySql> for &'r str { - fn accepts(ty: &MySqlTypeInfo) -> bool { + fn compatible(ty: &MySqlTypeInfo) -> bool { matches!( ty.r#type, ColumnType::VarChar @@ -38,7 +28,17 @@ impl<'r> Decode<'r, MySql> for &'r str { | ColumnType::Enum ) && ty.char_set == 224 } +} +impl Encode<'_, MySql> for &'_ str { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + buf.put_str_lenenc(self); + + IsNull::No + } +} + +impl<'r> Decode<'r, MySql> for &'r str { fn decode(value: MySqlValueRef<'r>) -> Result { value.as_str() } @@ -48,6 +48,10 @@ impl Type for String { fn type_info() -> MySqlTypeInfo { >::type_info() } + + fn compatible(ty: &MySqlTypeInfo) -> bool { + >::compatible(ty) + } } impl Encode<'_, MySql> for String { @@ -57,10 +61,6 @@ impl Encode<'_, MySql> for String { } impl Decode<'_, MySql> for String { - fn accepts(ty: &MySqlTypeInfo) -> bool { - <&str as Decode>::accepts(ty) - } - fn decode(value: MySqlValueRef<'_>) -> Result { <&str as Decode>::decode(value).map(ToOwned::to_owned) } diff --git a/sqlx-core/src/mysql/types/time.rs b/sqlx-core/src/mysql/types/time.rs index 31c4c5ca..7ce89174 100644 --- a/sqlx-core/src/mysql/types/time.rs +++ b/sqlx-core/src/mysql/types/time.rs @@ -5,7 +5,7 @@ use byteorder::{ByteOrder, LittleEndian}; use bytes::Buf; use time::{Date, OffsetDateTime, PrimitiveDateTime, Time, UtcOffset}; -use crate::decode::{accepts, Decode}; +use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::mysql::protocol::text::ColumnType; @@ -17,6 +17,10 @@ impl Type for OffsetDateTime { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Timestamp) } + + fn compatible(ty: &MySqlTypeInfo) -> bool { + matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp) + } } impl Encode<'_, MySql> for OffsetDateTime { @@ -29,10 +33,6 @@ impl Encode<'_, MySql> for OffsetDateTime { } impl<'r> Decode<'r, MySql> for OffsetDateTime { - fn accepts(ty: &MySqlTypeInfo) -> bool { - matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp) - } - fn decode(value: MySqlValueRef<'r>) -> Result { let primitive: PrimitiveDateTime = Decode::::decode(value)?; @@ -75,10 +75,6 @@ impl Encode<'_, MySql> for Time { } impl<'r> Decode<'r, MySql> for Time { - fn accepts(ty: &MySqlTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: MySqlValueRef<'r>) -> Result { match value.format() { MySqlValueFormat::Binary => { @@ -138,10 +134,6 @@ impl Encode<'_, MySql> for Date { } impl<'r> Decode<'r, MySql> for Date { - fn accepts(ty: &MySqlTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: MySqlValueRef<'r>) -> Result { match value.format() { MySqlValueFormat::Binary => decode_date(&value.as_bytes()?[1..]), @@ -191,10 +183,6 @@ impl Encode<'_, MySql> for PrimitiveDateTime { } impl<'r> Decode<'r, MySql> for PrimitiveDateTime { - fn accepts(ty: &MySqlTypeInfo) -> bool { - matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp) - } - fn decode(value: MySqlValueRef<'r>) -> Result { match value.format() { MySqlValueFormat::Binary => { diff --git a/sqlx-core/src/mysql/types/uint.rs b/sqlx-core/src/mysql/types/uint.rs index 4a54fc9d..d57b4239 100644 --- a/sqlx-core/src/mysql/types/uint.rs +++ b/sqlx-core/src/mysql/types/uint.rs @@ -17,28 +17,55 @@ fn uint_type_info(ty: ColumnType) -> MySqlTypeInfo { } } +fn uint_compatible(ty: &MySqlTypeInfo) -> bool { + matches!( + ty.r#type, + ColumnType::Tiny + | ColumnType::Short + | ColumnType::Long + | ColumnType::Int24 + | ColumnType::LongLong + ) && ty.flags.contains(ColumnFlags::UNSIGNED) +} + impl Type for u8 { fn type_info() -> MySqlTypeInfo { uint_type_info(ColumnType::Tiny) } + + fn compatible(ty: &MySqlTypeInfo) -> bool { + uint_compatible(ty) + } } impl Type for u16 { fn type_info() -> MySqlTypeInfo { uint_type_info(ColumnType::Short) } + + fn compatible(ty: &MySqlTypeInfo) -> bool { + uint_compatible(ty) + } } impl Type for u32 { fn type_info() -> MySqlTypeInfo { uint_type_info(ColumnType::Long) } + + fn compatible(ty: &MySqlTypeInfo) -> bool { + uint_compatible(ty) + } } impl Type for u64 { fn type_info() -> MySqlTypeInfo { uint_type_info(ColumnType::LongLong) } + + fn compatible(ty: &MySqlTypeInfo) -> bool { + uint_compatible(ty) + } } impl Encode<'_, MySql> for u8 { @@ -73,17 +100,6 @@ impl Encode<'_, MySql> for u64 { } } -fn uint_accepts(ty: &MySqlTypeInfo) -> bool { - matches!( - ty.r#type, - ColumnType::Tiny - | ColumnType::Short - | ColumnType::Long - | ColumnType::Int24 - | ColumnType::LongLong - ) && ty.flags.contains(ColumnFlags::UNSIGNED) -} - fn uint_decode(value: MySqlValueRef<'_>) -> Result { Ok(match value.format() { MySqlValueFormat::Text => value.as_str()?.parse()?, @@ -95,40 +111,24 @@ fn uint_decode(value: MySqlValueRef<'_>) -> Result { } impl Decode<'_, MySql> for u8 { - fn accepts(ty: &MySqlTypeInfo) -> bool { - uint_accepts(ty) - } - fn decode(value: MySqlValueRef<'_>) -> Result { uint_decode(value)?.try_into().map_err(Into::into) } } impl Decode<'_, MySql> for u16 { - fn accepts(ty: &MySqlTypeInfo) -> bool { - uint_accepts(ty) - } - fn decode(value: MySqlValueRef<'_>) -> Result { uint_decode(value)?.try_into().map_err(Into::into) } } impl Decode<'_, MySql> for u32 { - fn accepts(ty: &MySqlTypeInfo) -> bool { - uint_accepts(ty) - } - fn decode(value: MySqlValueRef<'_>) -> Result { uint_decode(value)?.try_into().map_err(Into::into) } } impl Decode<'_, MySql> for u64 { - fn accepts(ty: &MySqlTypeInfo) -> bool { - uint_accepts(ty) - } - fn decode(value: MySqlValueRef<'_>) -> Result { uint_decode(value)?.try_into().map_err(Into::into) } diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs index fde4a38c..c49cc7a6 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -23,7 +23,7 @@ pub use message::PgSeverity; pub use options::{PgConnectOptions, PgSslMode}; pub use row::PgRow; pub use transaction::PgTransactionManager; -pub use type_info::{PgTypeInfo, PgTypeKind}; +pub use type_info::PgTypeInfo; pub use value::{PgValue, PgValueFormat, PgValueRef}; /// An alias for [`Pool`][crate::pool::Pool], specialized for Postgres. diff --git a/sqlx-core/src/postgres/type_info.rs b/sqlx-core/src/postgres/type_info.rs index 3047e1db..9baf7353 100644 --- a/sqlx-core/src/postgres/type_info.rs +++ b/sqlx-core/src/postgres/type_info.rs @@ -1,6 +1,7 @@ #![allow(dead_code)] use std::fmt::{self, Display, Formatter}; +use std::ops::Deref; use std::sync::Arc; use crate::ext::ustr::UStr; @@ -11,10 +12,18 @@ use crate::type_info::TypeInfo; #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] pub struct PgTypeInfo(pub(crate) PgType); +impl Deref for PgTypeInfo { + type Target = PgType; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + #[derive(Debug, Clone)] #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] #[repr(u32)] -pub(crate) enum PgType { +pub enum PgType { Bool, Bytea, Char, @@ -118,7 +127,7 @@ pub(crate) enum PgType { #[derive(Debug, Clone)] #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] -pub(crate) struct PgCustomType { +pub struct PgCustomType { pub(crate) oid: u32, pub(crate) name: UStr, pub(crate) kind: PgTypeKind, @@ -418,6 +427,102 @@ impl PgType { }) } + pub(crate) fn display_name(&self) -> &str { + match self { + PgType::Bool => "BOOL", + PgType::Bytea => "BYTEA", + PgType::Char => "\"CHAR\"", + PgType::Name => "NAME", + PgType::Int8 => "INT8", + PgType::Int2 => "INT2", + PgType::Int4 => "INT4", + PgType::Text => "TEXT", + PgType::Oid => "OID", + PgType::Json => "JSON", + PgType::JsonArray => "JSON[]", + PgType::Point => "POINT", + PgType::Lseg => "LSEG", + PgType::Path => "PATH", + PgType::Box => "BOX", + PgType::Polygon => "POLYGON", + PgType::Line => "LINE", + PgType::LineArray => "LINE[]", + PgType::Cidr => "CIDR", + PgType::CidrArray => "CIDR[]", + PgType::Float4 => "FLOAT4", + PgType::Float8 => "FLOAT8", + PgType::Unknown => "UNKNOWN", + PgType::Circle => "CIRCLE", + PgType::CircleArray => "CIRCLE[]", + PgType::Macaddr8 => "MACADDR8", + PgType::Macaddr8Array => "MACADDR8[]", + PgType::Macaddr => "MACADDR", + PgType::Inet => "INET", + PgType::BoolArray => "BOOL[]", + PgType::ByteaArray => "BYTEA[]", + PgType::CharArray => "\"CHAR\"[]", + PgType::NameArray => "NAME[]", + PgType::Int2Array => "INT2[]", + PgType::Int4Array => "INT4[]", + PgType::TextArray => "TEXT[]", + PgType::BpcharArray => "CHAR[]", + PgType::VarcharArray => "VARCHAR[]", + PgType::Int8Array => "INT8[]", + PgType::PointArray => "POINT[]", + PgType::LsegArray => "LSEG[]", + PgType::PathArray => "PATH[]", + PgType::BoxArray => "BOX[]", + PgType::Float4Array => "FLOAT4[]", + PgType::Float8Array => "FLOAT8[]", + PgType::PolygonArray => "POLYGON[]", + PgType::OidArray => "OID[]", + PgType::MacaddrArray => "MACADDR[]", + PgType::InetArray => "INET[]", + PgType::Bpchar => "CHAR", + PgType::Varchar => "VARCHAR", + PgType::Date => "DATE", + PgType::Time => "TIME", + PgType::Timestamp => "TIMESTAMP", + PgType::TimestampArray => "TIMESTAMP[]", + PgType::DateArray => "DATE[]", + PgType::TimeArray => "TIME[]", + PgType::Timestamptz => "TIMESTAMPTZ", + PgType::TimestamptzArray => "TIMESTAMPTZ[]", + PgType::NumericArray => "NUMERIC[]", + PgType::Timetz => "TIMETZ", + PgType::TimetzArray => "TIMETZ[]", + PgType::Bit => "BIT", + PgType::BitArray => "BIT[]", + PgType::Varbit => "VARBIT", + PgType::VarbitArray => "VARBIT[]", + PgType::Numeric => "NUMERIC", + PgType::Record => "RECORD", + PgType::Interval => "INTERVAL", + PgType::RecordArray => "RECORD[]", + PgType::Uuid => "UUID", + PgType::UuidArray => "UUID[]", + PgType::Jsonb => "JSONB", + PgType::JsonbArray => "JSONB[]", + PgType::Int4Range => "INT4RANGE", + PgType::Int4RangeArray => "INT4RANGE[]", + PgType::NumRange => "NUMRANGE", + PgType::NumRangeArray => "NUMRANGE[]", + PgType::TsRange => "TSRANGE", + PgType::TsRangeArray => "TSRANGE[]", + PgType::TstzRange => "TSTZRANGE", + PgType::TstzRangeArray => "TSTZRANGE[]", + PgType::DateRange => "DATERANGE", + PgType::DateRangeArray => "DATERANGE[]", + PgType::Int8Range => "INT8RANGE", + PgType::Int8RangeArray => "INT8RANGE[]", + PgType::Jsonpath => "JSONPATH", + PgType::JsonpathArray => "JSONPATH[]", + PgType::Custom(ty) => &*ty.name, + PgType::DeclareWithOid(_) => "?", + PgType::DeclareWithName(name) => name, + } + } + pub(crate) fn name(&self) -> &str { match self { PgType::Bool => "bool", @@ -613,7 +718,11 @@ impl PgType { } } -impl TypeInfo for PgTypeInfo {} +impl TypeInfo for PgTypeInfo { + fn name(&self) -> &str { + self.0.display_name() + } +} impl PartialEq for PgCustomType { fn eq(&self, other: &PgCustomType) -> bool { diff --git a/sqlx-core/src/postgres/types/array.rs b/sqlx-core/src/postgres/types/array.rs index 517dc119..9dedfa03 100644 --- a/sqlx-core/src/postgres/types/array.rs +++ b/sqlx-core/src/postgres/types/array.rs @@ -1,6 +1,6 @@ use bytes::Buf; -use crate::decode::{accepts, Decode}; +use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::postgres::type_info::PgType; @@ -14,6 +14,10 @@ where fn type_info() -> PgTypeInfo { <[T] as Type>::type_info() } + + fn compatible(ty: &PgTypeInfo) -> bool { + <[T] as Type>::compatible(ty) + } } impl Type for Vec> @@ -23,6 +27,10 @@ where fn type_info() -> PgTypeInfo { as Type>::type_info() } + + fn compatible(ty: &PgTypeInfo) -> bool { + as Type>::compatible(ty) + } } impl<'q, T> Encode<'q, Postgres> for Vec @@ -66,18 +74,11 @@ where } } -// TODO: Array decoding in PostgreSQL *could* allow 'r (row) lifetime of elements if we can figure -// out a way for the TEXT encoding to use some shared memory somewhere. - impl<'r, T> Decode<'r, Postgres> for Vec where T: for<'a> Decode<'a, Postgres> + Type, Self: Type, { - fn accepts(ty: &PgTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: PgValueRef<'r>) -> Result { let element_type_info = T::type_info(); let format = value.format(); diff --git a/sqlx-core/src/postgres/types/bigdecimal.rs b/sqlx-core/src/postgres/types/bigdecimal.rs index 7383a1d8..cb3cfd43 100644 --- a/sqlx-core/src/postgres/types/bigdecimal.rs +++ b/sqlx-core/src/postgres/types/bigdecimal.rs @@ -4,7 +4,7 @@ use std::convert::{TryFrom, TryInto}; use bigdecimal::BigDecimal; use num_bigint::{BigInt, Sign}; -use crate::decode::{accepts, Decode}; +use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::postgres::types::numeric::{PgNumeric, PgNumericSign}; @@ -165,10 +165,6 @@ impl Encode<'_, Postgres> for BigDecimal { } impl Decode<'_, Postgres> for BigDecimal { - fn accepts(ty: &PgTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: PgValueRef<'_>) -> Result { match value.format() { PgValueFormat::Binary => PgNumeric::decode(value.as_bytes()?)?.try_into(), diff --git a/sqlx-core/src/postgres/types/bool.rs b/sqlx-core/src/postgres/types/bool.rs index fb8def18..457d07ce 100644 --- a/sqlx-core/src/postgres/types/bool.rs +++ b/sqlx-core/src/postgres/types/bool.rs @@ -1,4 +1,4 @@ -use crate::decode::{accepts, Decode}; +use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; @@ -31,10 +31,6 @@ impl Encode<'_, Postgres> for bool { } impl Decode<'_, Postgres> for bool { - fn accepts(ty: &PgTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: PgValueRef<'_>) -> Result { Ok(match value.format() { PgValueFormat::Binary => value.as_bytes()?[0] != 0, diff --git a/sqlx-core/src/postgres/types/bytes.rs b/sqlx-core/src/postgres/types/bytes.rs index ba3b84fb..37d158ad 100644 --- a/sqlx-core/src/postgres/types/bytes.rs +++ b/sqlx-core/src/postgres/types/bytes.rs @@ -1,4 +1,4 @@ -use crate::decode::{accepts, Decode}; +use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; @@ -55,10 +55,6 @@ impl Encode<'_, Postgres> for Vec { } impl<'r> Decode<'r, Postgres> for &'r [u8] { - fn accepts(ty: &PgTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: PgValueRef<'r>) -> Result { match value.format() { PgValueFormat::Binary => value.as_bytes(), @@ -70,10 +66,6 @@ impl<'r> Decode<'r, Postgres> for &'r [u8] { } impl Decode<'_, Postgres> for Vec { - fn accepts(ty: &PgTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: PgValueRef<'_>) -> Result { Ok(match value.format() { PgValueFormat::Binary => value.as_bytes()?.to_owned(), diff --git a/sqlx-core/src/postgres/types/chrono.rs b/sqlx-core/src/postgres/types/chrono.rs index 97352f7f..19ac89a8 100644 --- a/sqlx-core/src/postgres/types/chrono.rs +++ b/sqlx-core/src/postgres/types/chrono.rs @@ -2,7 +2,7 @@ use std::mem; use chrono::{DateTime, Duration, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc}; -use crate::decode::{accepts, Decode}; +use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; @@ -96,10 +96,6 @@ impl Encode<'_, Postgres> for NaiveTime { } impl<'r> Decode<'r, Postgres> for NaiveTime { - fn accepts(ty: &PgTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: PgValueRef<'r>) -> Result { Ok(match value.format() { PgValueFormat::Binary => { @@ -126,10 +122,6 @@ impl Encode<'_, Postgres> for NaiveDate { } impl<'r> Decode<'r, Postgres> for NaiveDate { - fn accepts(ty: &PgTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: PgValueRef<'r>) -> Result { Ok(match value.format() { PgValueFormat::Binary => { @@ -160,10 +152,6 @@ impl Encode<'_, Postgres> for NaiveDateTime { } impl<'r> Decode<'r, Postgres> for NaiveDateTime { - fn accepts(ty: &PgTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: PgValueRef<'r>) -> Result { Ok(match value.format() { PgValueFormat::Binary => { @@ -202,10 +190,6 @@ impl Encode<'_, Postgres> for DateTime { } impl<'r> Decode<'r, Postgres> for DateTime { - fn accepts(ty: &PgTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: PgValueRef<'r>) -> Result { let naive = >::decode(value)?; Ok(Local.from_utc_datetime(&naive)) @@ -213,10 +197,6 @@ impl<'r> Decode<'r, Postgres> for DateTime { } impl<'r> Decode<'r, Postgres> for DateTime { - fn accepts(ty: &PgTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: PgValueRef<'r>) -> Result { let naive = >::decode(value)?; Ok(Utc.from_utc_datetime(&naive)) diff --git a/sqlx-core/src/postgres/types/float.rs b/sqlx-core/src/postgres/types/float.rs index 34b1932f..76cb4107 100644 --- a/sqlx-core/src/postgres/types/float.rs +++ b/sqlx-core/src/postgres/types/float.rs @@ -1,6 +1,6 @@ use byteorder::{BigEndian, ByteOrder}; -use crate::decode::{accepts, Decode}; +use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; @@ -33,10 +33,6 @@ impl Encode<'_, Postgres> for f32 { } impl Decode<'_, Postgres> for f32 { - fn accepts(ty: &PgTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: PgValueRef<'_>) -> Result { Ok(match value.format() { PgValueFormat::Binary => BigEndian::read_f32(value.as_bytes()?), @@ -72,10 +68,6 @@ impl Encode<'_, Postgres> for f64 { } impl Decode<'_, Postgres> for f64 { - fn accepts(ty: &PgTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: PgValueRef<'_>) -> Result { Ok(match value.format() { PgValueFormat::Binary => BigEndian::read_f64(value.as_bytes()?), diff --git a/sqlx-core/src/postgres/types/int.rs b/sqlx-core/src/postgres/types/int.rs index d47b2229..4fc2bf39 100644 --- a/sqlx-core/src/postgres/types/int.rs +++ b/sqlx-core/src/postgres/types/int.rs @@ -1,6 +1,6 @@ use byteorder::{BigEndian, ByteOrder}; -use crate::decode::{accepts, Decode}; +use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; @@ -33,10 +33,6 @@ impl Encode<'_, Postgres> for i8 { } impl Decode<'_, Postgres> for i8 { - fn accepts(ty: &PgTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: PgValueRef<'_>) -> Result { // note: in the TEXT encoding, a value of "0" here is encoded as an empty string Ok(value.as_bytes()?.get(0).copied().unwrap_or_default() as i8) @@ -70,10 +66,6 @@ impl Encode<'_, Postgres> for i16 { } impl Decode<'_, Postgres> for i16 { - fn accepts(ty: &PgTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: PgValueRef<'_>) -> Result { Ok(match value.format() { PgValueFormat::Binary => BigEndian::read_i16(value.as_bytes()?), @@ -109,10 +101,6 @@ impl Encode<'_, Postgres> for u32 { } impl Decode<'_, Postgres> for u32 { - fn accepts(ty: &PgTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: PgValueRef<'_>) -> Result { Ok(match value.format() { PgValueFormat::Binary => BigEndian::read_u32(value.as_bytes()?), @@ -148,10 +136,6 @@ impl Encode<'_, Postgres> for i32 { } impl Decode<'_, Postgres> for i32 { - fn accepts(ty: &PgTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: PgValueRef<'_>) -> Result { Ok(match value.format() { PgValueFormat::Binary => BigEndian::read_i32(value.as_bytes()?), @@ -187,10 +171,6 @@ impl Encode<'_, Postgres> for i64 { } impl Decode<'_, Postgres> for i64 { - fn accepts(ty: &PgTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: PgValueRef<'_>) -> Result { Ok(match value.format() { PgValueFormat::Binary => BigEndian::read_i64(value.as_bytes()?), diff --git a/sqlx-core/src/postgres/types/ipnetwork.rs b/sqlx-core/src/postgres/types/ipnetwork.rs index cc1c3864..5d579e86 100644 --- a/sqlx-core/src/postgres/types/ipnetwork.rs +++ b/sqlx-core/src/postgres/types/ipnetwork.rs @@ -28,6 +28,10 @@ impl Type for IpNetwork { fn type_info() -> PgTypeInfo { PgTypeInfo::INET } + + fn compatible(ty: &PgTypeInfo) -> bool { + *ty == PgTypeInfo::CIDR || *ty == PgTypeInfo::INET + } } impl Type for [IpNetwork] { @@ -40,6 +44,10 @@ impl Type for Vec { fn type_info() -> PgTypeInfo { <[IpNetwork] as Type>::type_info() } + + fn compatible(ty: &PgTypeInfo) -> bool { + <[IpNetwork] as Type>::compatible(ty) + } } impl Encode<'_, Postgres> for IpNetwork { @@ -77,10 +85,6 @@ impl Encode<'_, Postgres> for IpNetwork { } impl Decode<'_, Postgres> for IpNetwork { - fn accepts(ty: &PgTypeInfo) -> bool { - *ty == PgTypeInfo::CIDR || *ty == PgTypeInfo::INET - } - fn decode(value: PgValueRef<'_>) -> Result { let bytes = match value.format() { PgValueFormat::Binary => value.as_bytes()?, diff --git a/sqlx-core/src/postgres/types/json.rs b/sqlx-core/src/postgres/types/json.rs index 11854acf..64cd884c 100644 --- a/sqlx-core/src/postgres/types/json.rs +++ b/sqlx-core/src/postgres/types/json.rs @@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize}; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; +use crate::postgres::types::array_compatible; use crate::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; use crate::types::{Json, Type}; @@ -16,12 +17,20 @@ impl Type for Json { fn type_info() -> PgTypeInfo { PgTypeInfo::JSONB } + + fn compatible(ty: &PgTypeInfo) -> bool { + *ty == PgTypeInfo::JSON || *ty == PgTypeInfo::JSONB + } } impl Type for [Json] { fn type_info() -> PgTypeInfo { PgTypeInfo::JSONB_ARRAY } + + fn compatible(ty: &PgTypeInfo) -> bool { + array_compatible::>(ty) + } } impl Type for Vec> { @@ -49,10 +58,6 @@ impl<'r, T: 'r> Decode<'r, Postgres> for Json where T: Deserialize<'r>, { - fn accepts(ty: &PgTypeInfo) -> bool { - *ty == PgTypeInfo::JSON || *ty == PgTypeInfo::JSONB - } - fn decode(value: PgValueRef<'r>) -> Result { let mut buf = value.as_bytes()?; diff --git a/sqlx-core/src/postgres/types/mod.rs b/sqlx-core/src/postgres/types/mod.rs index ce355c54..b65c29d4 100644 --- a/sqlx-core/src/postgres/types/mod.rs +++ b/sqlx-core/src/postgres/types/mod.rs @@ -128,6 +128,10 @@ //! a potentially `NULL` value from Postgres. //! +use crate::postgres::type_info::PgTypeKind; +use crate::postgres::{PgTypeInfo, Postgres}; +use crate::types::Type; + mod array; mod bool; mod bytes; @@ -165,3 +169,14 @@ pub use range::PgRange; // but the interface is not considered part of the public API #[doc(hidden)] pub use record::{PgRecordDecoder, PgRecordEncoder}; + +// Type::compatible impl appropriate for arrays +fn array_compatible>(ty: &PgTypeInfo) -> bool { + // we require the declared type to be an _array_ with an + // element type that is acceptable + if let PgTypeKind::Array(element) = &ty.kind() { + return E::compatible(&element); + } + + false +} diff --git a/sqlx-core/src/postgres/types/range.rs b/sqlx-core/src/postgres/types/range.rs index 56138982..760249f7 100644 --- a/sqlx-core/src/postgres/types/range.rs +++ b/sqlx-core/src/postgres/types/range.rs @@ -7,9 +7,8 @@ use bytes::Buf; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; -use crate::postgres::{ - PgArgumentBuffer, PgTypeInfo, PgTypeKind, PgValueFormat, PgValueRef, Postgres, -}; +use crate::postgres::type_info::PgTypeKind; +use crate::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; use crate::types::Type; // https://github.com/postgres/postgres/blob/2f48ede080f42b97b594fb14102c82ca1001b80c/src/include/utils/rangetypes.h#L35-L44 @@ -116,12 +115,20 @@ impl Type for PgRange { fn type_info() -> PgTypeInfo { PgTypeInfo::INT4_RANGE } + + fn compatible(ty: &PgTypeInfo) -> bool { + range_compatible::(ty) + } } impl Type for PgRange { fn type_info() -> PgTypeInfo { PgTypeInfo::INT8_RANGE } + + fn compatible(ty: &PgTypeInfo) -> bool { + range_compatible::(ty) + } } #[cfg(feature = "bigdecimal")] @@ -129,6 +136,10 @@ impl Type for PgRange { fn type_info() -> PgTypeInfo { PgTypeInfo::NUM_RANGE } + + fn compatible(ty: &PgTypeInfo) -> bool { + range_compatible::(ty) + } } #[cfg(feature = "chrono")] @@ -136,6 +147,10 @@ impl Type for PgRange { fn type_info() -> PgTypeInfo { PgTypeInfo::DATE_RANGE } + + fn compatible(ty: &PgTypeInfo) -> bool { + range_compatible::(ty) + } } #[cfg(feature = "chrono")] @@ -143,6 +158,10 @@ impl Type for PgRange { fn type_info() -> PgTypeInfo { PgTypeInfo::TS_RANGE } + + fn compatible(ty: &PgTypeInfo) -> bool { + range_compatible::(ty) + } } #[cfg(feature = "chrono")] @@ -150,6 +169,10 @@ impl Type for PgRange> { fn type_info() -> PgTypeInfo { PgTypeInfo::TSTZ_RANGE } + + fn compatible(ty: &PgTypeInfo) -> bool { + range_compatible::>(ty) + } } #[cfg(feature = "time")] @@ -157,6 +180,10 @@ impl Type for PgRange { fn type_info() -> PgTypeInfo { PgTypeInfo::DATE_RANGE } + + fn compatible(ty: &PgTypeInfo) -> bool { + range_compatible::(ty) + } } #[cfg(feature = "time")] @@ -164,6 +191,10 @@ impl Type for PgRange { fn type_info() -> PgTypeInfo { PgTypeInfo::TS_RANGE } + + fn compatible(ty: &PgTypeInfo) -> bool { + range_compatible::(ty) + } } #[cfg(feature = "time")] @@ -171,6 +202,10 @@ impl Type for PgRange { fn type_info() -> PgTypeInfo { PgTypeInfo::TSTZ_RANGE } + + fn compatible(ty: &PgTypeInfo) -> bool { + range_compatible::(ty) + } } impl Type for [PgRange] { @@ -335,16 +370,6 @@ impl<'r, T> Decode<'r, Postgres> for PgRange where T: Type + for<'a> Decode<'a, Postgres>, { - fn accepts(ty: &PgTypeInfo) -> bool { - // we require the declared type to be a _range_ with an - // element type that is acceptable - if let PgTypeKind::Range(element) = &ty.0.kind() { - return T::accepts(&element); - } - - false - } - fn decode(value: PgValueRef<'r>) -> Result { match value.format { PgValueFormat::Binary => { @@ -528,3 +553,13 @@ where Ok(()) } } + +fn range_compatible>(ty: &PgTypeInfo) -> bool { + // we require the declared type to be a _range_ with an + // element type that is acceptable + if let PgTypeKind::Range(element) = &ty.kind() { + return E::compatible(&element); + } + + false +} diff --git a/sqlx-core/src/postgres/types/record.rs b/sqlx-core/src/postgres/types/record.rs index c059c09e..b58ab72c 100644 --- a/sqlx-core/src/postgres/types/record.rs +++ b/sqlx-core/src/postgres/types/record.rs @@ -3,10 +3,8 @@ use bytes::Buf; use crate::decode::Decode; use crate::encode::Encode; use crate::error::{mismatched_types, BoxDynError}; -use crate::postgres::type_info::PgType; -use crate::postgres::{ - PgArgumentBuffer, PgTypeInfo, PgTypeKind, PgValueFormat, PgValueRef, Postgres, -}; +use crate::postgres::type_info::{PgType, PgTypeKind}; +use crate::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; use crate::types::Type; #[doc(hidden)] @@ -128,8 +126,8 @@ impl<'r> PgRecordDecoder<'r> { self.ind += 1; if let Some(ty) = &element_type_opt { - if !T::accepts(ty) { - return Err(mismatched_types::(&T::type_info(), ty)); + if !T::compatible(ty) { + return Err(mismatched_types::(ty)); } } diff --git a/sqlx-core/src/postgres/types/str.rs b/sqlx-core/src/postgres/types/str.rs index 1a6e7d16..3607a4b8 100644 --- a/sqlx-core/src/postgres/types/str.rs +++ b/sqlx-core/src/postgres/types/str.rs @@ -1,6 +1,7 @@ use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; +use crate::postgres::types::array_compatible; use crate::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueRef, Postgres}; use crate::types::Type; @@ -8,18 +9,37 @@ impl Type for str { fn type_info() -> PgTypeInfo { PgTypeInfo::TEXT } + + fn compatible(ty: &PgTypeInfo) -> bool { + [ + PgTypeInfo::TEXT, + PgTypeInfo::NAME, + PgTypeInfo::BPCHAR, + PgTypeInfo::VARCHAR, + PgTypeInfo::UNKNOWN, + ] + .contains(ty) + } } impl Type for [&'_ str] { fn type_info() -> PgTypeInfo { PgTypeInfo::TEXT_ARRAY } + + fn compatible(ty: &PgTypeInfo) -> bool { + array_compatible::<&str>(ty) + } } impl Type for Vec<&'_ str> { fn type_info() -> PgTypeInfo { <[&str] as Type>::type_info() } + + fn compatible(ty: &PgTypeInfo) -> bool { + <[&str] as Type>::compatible(ty) + } } impl Encode<'_, Postgres> for &'_ str { @@ -37,17 +57,6 @@ impl Encode<'_, Postgres> for String { } impl<'r> Decode<'r, Postgres> for &'r str { - fn accepts(ty: &PgTypeInfo) -> bool { - [ - PgTypeInfo::TEXT, - PgTypeInfo::NAME, - PgTypeInfo::BPCHAR, - PgTypeInfo::VARCHAR, - PgTypeInfo::UNKNOWN, - ] - .contains(ty) - } - fn decode(value: PgValueRef<'r>) -> Result { Ok(value.as_str()?) } @@ -57,25 +66,33 @@ impl Type for String { fn type_info() -> PgTypeInfo { <&str as Type>::type_info() } + + fn compatible(ty: &PgTypeInfo) -> bool { + <&str as Type>::compatible(ty) + } } impl Type for [String] { fn type_info() -> PgTypeInfo { <[&str] as Type>::type_info() } + + fn compatible(ty: &PgTypeInfo) -> bool { + <[&str] as Type>::compatible(ty) + } } impl Type for Vec { fn type_info() -> PgTypeInfo { <[String] as Type>::type_info() } + + fn compatible(ty: &PgTypeInfo) -> bool { + <[String] as Type>::compatible(ty) + } } impl Decode<'_, Postgres> for String { - fn accepts(ty: &PgTypeInfo) -> bool { - <&str as Decode>::accepts(ty) - } - fn decode(value: PgValueRef<'_>) -> Result { Ok(value.as_str()?.to_owned()) } diff --git a/sqlx-core/src/postgres/types/time.rs b/sqlx-core/src/postgres/types/time.rs index ed357cac..3338680d 100644 --- a/sqlx-core/src/postgres/types/time.rs +++ b/sqlx-core/src/postgres/types/time.rs @@ -1,6 +1,6 @@ use time::{date, offset, Date, Duration, OffsetDateTime, PrimitiveDateTime, Time}; -use crate::decode::{accepts, Decode}; +use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; @@ -96,10 +96,6 @@ impl Encode<'_, Postgres> for Time { } impl<'r> Decode<'r, Postgres> for Time { - fn accepts(ty: &PgTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: PgValueRef<'r>) -> Result { Ok(match value.format() { PgValueFormat::Binary => { @@ -141,10 +137,6 @@ impl Encode<'_, Postgres> for Date { } impl<'r> Decode<'r, Postgres> for Date { - fn accepts(ty: &PgTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: PgValueRef<'r>) -> Result { Ok(match value.format() { PgValueFormat::Binary => { @@ -171,10 +163,6 @@ impl Encode<'_, Postgres> for PrimitiveDateTime { } impl<'r> Decode<'r, Postgres> for PrimitiveDateTime { - fn accepts(ty: &PgTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: PgValueRef<'r>) -> Result { Ok(match value.format() { PgValueFormat::Binary => { @@ -232,10 +220,6 @@ impl Encode<'_, Postgres> for OffsetDateTime { } impl<'r> Decode<'r, Postgres> for OffsetDateTime { - fn accepts(ty: &PgTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: PgValueRef<'r>) -> Result { Ok(>::decode(value)?.assume_utc()) } diff --git a/sqlx-core/src/postgres/types/tuple.rs b/sqlx-core/src/postgres/types/tuple.rs index 53e7145d..67a5cb27 100644 --- a/sqlx-core/src/postgres/types/tuple.rs +++ b/sqlx-core/src/postgres/types/tuple.rs @@ -1,4 +1,4 @@ -use crate::decode::{accepts, Decode}; +use crate::decode::Decode; use crate::error::BoxDynError; use crate::postgres::types::PgRecordDecoder; use crate::postgres::{PgTypeInfo, PgValueRef, Postgres}; @@ -33,10 +33,6 @@ macro_rules! impl_type_for_tuple { $($T: Type,)* $($T: for<'a> Decode<'a, Postgres>,)* { - fn accepts(ty: &PgTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: PgValueRef<'r>) -> Result { #[allow(unused)] let mut decoder = PgRecordDecoder::new(value)?; diff --git a/sqlx-core/src/postgres/types/uuid.rs b/sqlx-core/src/postgres/types/uuid.rs index 926c1719..d309a7ee 100644 --- a/sqlx-core/src/postgres/types/uuid.rs +++ b/sqlx-core/src/postgres/types/uuid.rs @@ -1,6 +1,6 @@ use uuid::Uuid; -use crate::decode::{accepts, Decode}; +use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; @@ -33,10 +33,6 @@ impl Encode<'_, Postgres> for Uuid { } impl Decode<'_, Postgres> for Uuid { - fn accepts(ty: &PgTypeInfo) -> bool { - accepts::(ty) - } - fn decode(value: PgValueRef<'_>) -> Result { match value.format() { PgValueFormat::Binary => Uuid::from_slice(value.as_bytes()?), diff --git a/sqlx-core/src/row.rs b/sqlx-core/src/row.rs index 219ba635..42971346 100644 --- a/sqlx-core/src/row.rs +++ b/sqlx-core/src/row.rs @@ -3,6 +3,7 @@ use std::fmt::Debug; use crate::database::{Database, HasValueRef}; use crate::decode::Decode; use crate::error::{mismatched_types, Error}; +use crate::types::Type; use crate::value::ValueRef; /// A type that can be used to index into a [`Row`]. @@ -16,6 +17,7 @@ use crate::value::ValueRef; /// [`Row`]: trait.Row.html /// [`get`]: trait.Row.html#method.get /// [`try_get`]: trait.Row.html#method.try_get +/// pub trait ColumnIndex: private_column_index::Sealed + Debug { /// Returns a valid positional index into the row, [`ColumnIndexOutOfBounds`], or, /// [`ColumnNotFound`]. @@ -89,7 +91,7 @@ pub trait Row: private_row::Sealed + Unpin + Send + Sync + 'static { fn get<'r, T, I>(&'r self, index: I) -> T where I: ColumnIndex, - T: Decode<'r, Self::Database>, + T: Decode<'r, Self::Database> + Type, { self.try_get::(index).unwrap() } @@ -132,18 +134,18 @@ pub trait Row: private_row::Sealed + Unpin + Send + Sync + 'static { fn try_get<'r, T, I>(&'r self, index: I) -> Result where I: ColumnIndex, - T: Decode<'r, Self::Database>, + T: Decode<'r, Self::Database> + Type, { let value = self.try_get_raw(&index)?; if !value.is_null() { - if let Some(actual_ty) = value.type_info() { + if let Some(ty) = value.type_info() { // NOTE: we opt-out of asserting the type equivalency of NULL because of the // high false-positive rate (e.g., `NULL` in Postgres is `TEXT`). - if !T::accepts(&actual_ty) { + if !T::compatible(&ty) { return Err(Error::ColumnDecode { index: format!("{:?}", index), - source: mismatched_types::(&actual_ty), + source: mismatched_types::(&ty), }); } } diff --git a/sqlx-core/src/sqlite/type_info.rs b/sqlx-core/src/sqlite/type_info.rs index 26d8ab9c..21bad685 100644 --- a/sqlx-core/src/sqlite/type_info.rs +++ b/sqlx-core/src/sqlite/type_info.rs @@ -31,15 +31,13 @@ pub struct SqliteTypeInfo(pub(crate) DataType); impl Display for SqliteTypeInfo { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - self.0.fmt(f) + f.pad(self.name()) } } -impl TypeInfo for SqliteTypeInfo {} - -impl Display for DataType { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.write_str(match self { +impl TypeInfo for SqliteTypeInfo { + fn name(&self) -> &str { + match self.0 { DataType::Text => "TEXT", DataType::Float => "FLOAT", DataType::Blob => "BLOB", @@ -49,7 +47,7 @@ impl Display for DataType { // non-standard extensions DataType::Bool => "BOOLEAN", DataType::Int64 => "BIGINT", - }) + } } } diff --git a/sqlx-core/src/sqlite/types/bool.rs b/sqlx-core/src/sqlite/types/bool.rs index 9a64162c..c92389b7 100644 --- a/sqlx-core/src/sqlite/types/bool.rs +++ b/sqlx-core/src/sqlite/types/bool.rs @@ -20,10 +20,6 @@ impl<'q> Encode<'q, Sqlite> for bool { } impl<'r> Decode<'r, Sqlite> for bool { - fn accepts(_ty: &SqliteTypeInfo) -> bool { - true - } - fn decode(value: SqliteValueRef<'r>) -> Result { Ok(value.int() != 0) } diff --git a/sqlx-core/src/sqlite/types/bytes.rs b/sqlx-core/src/sqlite/types/bytes.rs index a79fbba1..f3c97296 100644 --- a/sqlx-core/src/sqlite/types/bytes.rs +++ b/sqlx-core/src/sqlite/types/bytes.rs @@ -22,10 +22,6 @@ impl<'q> Encode<'q, Sqlite> for &'q [u8] { } impl<'r> Decode<'r, Sqlite> for &'r [u8] { - fn accepts(_ty: &SqliteTypeInfo) -> bool { - true - } - fn decode(value: SqliteValueRef<'r>) -> Result { Ok(value.blob()) } @@ -52,10 +48,6 @@ impl<'q> Encode<'q, Sqlite> for Vec { } impl<'r> Decode<'r, Sqlite> for Vec { - fn accepts(_ty: &SqliteTypeInfo) -> bool { - true - } - fn decode(value: SqliteValueRef<'r>) -> Result { Ok(value.blob().to_owned()) } diff --git a/sqlx-core/src/sqlite/types/float.rs b/sqlx-core/src/sqlite/types/float.rs index c25bf0ce..d8b2c3bd 100644 --- a/sqlx-core/src/sqlite/types/float.rs +++ b/sqlx-core/src/sqlite/types/float.rs @@ -20,10 +20,6 @@ impl<'q> Encode<'q, Sqlite> for f32 { } impl<'r> Decode<'r, Sqlite> for f32 { - fn accepts(_ty: &SqliteTypeInfo) -> bool { - true - } - fn decode(value: SqliteValueRef<'r>) -> Result { Ok(value.double() as f32) } @@ -44,10 +40,6 @@ impl<'q> Encode<'q, Sqlite> for f64 { } impl<'r> Decode<'r, Sqlite> for f64 { - fn accepts(_ty: &SqliteTypeInfo) -> bool { - true - } - fn decode(value: SqliteValueRef<'r>) -> Result { Ok(value.double()) } diff --git a/sqlx-core/src/sqlite/types/int.rs b/sqlx-core/src/sqlite/types/int.rs index c4307637..a25e7762 100644 --- a/sqlx-core/src/sqlite/types/int.rs +++ b/sqlx-core/src/sqlite/types/int.rs @@ -20,10 +20,6 @@ impl<'q> Encode<'q, Sqlite> for i32 { } impl<'r> Decode<'r, Sqlite> for i32 { - fn accepts(_ty: &SqliteTypeInfo) -> bool { - true - } - fn decode(value: SqliteValueRef<'r>) -> Result { Ok(value.int()) } @@ -44,10 +40,6 @@ impl<'q> Encode<'q, Sqlite> for i64 { } impl<'r> Decode<'r, Sqlite> for i64 { - fn accepts(_ty: &SqliteTypeInfo) -> bool { - true - } - fn decode(value: SqliteValueRef<'r>) -> Result { Ok(value.int64()) } diff --git a/sqlx-core/src/sqlite/types/mod.rs b/sqlx-core/src/sqlite/types/mod.rs index a7b98cd9..4fe953c6 100644 --- a/sqlx-core/src/sqlite/types/mod.rs +++ b/sqlx-core/src/sqlite/types/mod.rs @@ -19,10 +19,6 @@ //! a potentially `NULL` value from SQLite. //! -// NOTE: all types are compatible with all other types in SQLite -// so we explicitly opt-out of runtime type assertions by returning [true] for -// all implementations of [Decode::accepts] - mod bool; mod bytes; mod float; diff --git a/sqlx-core/src/sqlite/types/str.rs b/sqlx-core/src/sqlite/types/str.rs index e0db2327..6a3ed533 100644 --- a/sqlx-core/src/sqlite/types/str.rs +++ b/sqlx-core/src/sqlite/types/str.rs @@ -22,10 +22,6 @@ impl<'q> Encode<'q, Sqlite> for &'q str { } impl<'r> Decode<'r, Sqlite> for &'r str { - fn accepts(_ty: &SqliteTypeInfo) -> bool { - true - } - fn decode(value: SqliteValueRef<'r>) -> Result { value.text() } @@ -52,10 +48,6 @@ impl<'q> Encode<'q, Sqlite> for String { } impl<'r> Decode<'r, Sqlite> for String { - fn accepts(_ty: &SqliteTypeInfo) -> bool { - true - } - fn decode(value: SqliteValueRef<'r>) -> Result { value.text().map(ToOwned::to_owned) } diff --git a/sqlx-core/src/type_info.rs b/sqlx-core/src/type_info.rs index 49fcc16b..2ce59b07 100644 --- a/sqlx-core/src/type_info.rs +++ b/sqlx-core/src/type_info.rs @@ -1,10 +1,6 @@ use std::fmt::{Debug, Display}; /// Provides information about a SQL type for the database driver. -/// -/// Currently this only exposes type equality rules that should roughly match the interpretation -/// in a given database (e.g., in PostgreSQL `VARCHAR` and `TEXT` are roughly equivalent -/// apart from storage). pub trait TypeInfo: Debug + Display + Clone + PartialEq { /// Returns the database system name of the type. Length specifiers should not be included. /// Common type names are `VARCHAR`, `TEXT`, or `INT`. Type names should be uppercase. They diff --git a/sqlx-core/src/types/json.rs b/sqlx-core/src/types/json.rs index 01d4634f..9aab1c62 100644 --- a/sqlx-core/src/types/json.rs +++ b/sqlx-core/src/types/json.rs @@ -34,6 +34,10 @@ where fn type_info() -> DB::TypeInfo { as Type>::type_info() } + + fn compatible(ty: &DB::TypeInfo) -> bool { + as Type>::compatible(ty) + } } impl<'q, DB> Encode<'q, DB> for JsonValue @@ -53,10 +57,6 @@ where Json: Decode<'r, DB>, DB: Database, { - fn accepts(ty: &DB::TypeInfo) -> bool { - as Decode>::accepts(ty) - } - fn decode(value: >::ValueRef) -> Result { as Decode>::decode(value).map(|item| item.0) } @@ -70,6 +70,10 @@ where fn type_info() -> DB::TypeInfo { as Type>::type_info() } + + fn compatible(ty: &DB::TypeInfo) -> bool { + as Type>::compatible(ty) + } } // We don't have to implement Encode for JsonRawValue because that's covered by the default @@ -80,10 +84,6 @@ where Json: Decode<'r, DB>, DB: Database, { - fn accepts(ty: &DB::TypeInfo) -> bool { - as Decode>::accepts(ty) - } - fn decode(value: >::ValueRef) -> Result { as Decode>::decode(value).map(|item| item.0) } diff --git a/sqlx-core/src/types/mod.rs b/sqlx-core/src/types/mod.rs index e6a9df0a..18e18423 100644 --- a/sqlx-core/src/types/mod.rs +++ b/sqlx-core/src/types/mod.rs @@ -48,22 +48,45 @@ pub use json::Json; /// Indicates that a SQL type is supported for a database. pub trait Type { - /// Returns the canonical type information on the database for the type `T`. + /// Returns the canonical SQL type for this Rust type. + /// + /// When binding arguments, this is used to tell the database what is about to be sent; which, + /// the database then uses to guide query plans. This can be overridden by `Encode::produces`. + /// + /// A map of SQL types to Rust types is populated with this and used + /// to determine the type that is returned from the anonymous struct type from `query!`. fn type_info() -> DB::TypeInfo; + + /// Determines if this Rust type is compatible with the given SQL type. + /// + /// When decoding values from a row, this method is checked to determine if we should continue + /// or raise a runtime type mismatch error. + /// + /// When binding arguments with `query!` or `query_as!`, this method is consulted to determine + /// if the Rust type is acceptable. + fn compatible(ty: &DB::TypeInfo) -> bool { + *ty == Self::type_info() + } } // for references, the underlying SQL type is identical impl, DB: Database> Type for &'_ T { - #[inline] fn type_info() -> DB::TypeInfo { >::type_info() } + + fn compatible(ty: &DB::TypeInfo) -> bool { + >::compatible(ty) + } } // for optionals, the underlying SQL type is identical impl, DB: Database> Type for Option { - #[inline] fn type_info() -> DB::TypeInfo { >::type_info() } + + fn compatible(ty: &DB::TypeInfo) -> bool { + >::compatible(ty) + } } diff --git a/sqlx-core/src/value.rs b/sqlx-core/src/value.rs index ec115109..1ef4c3b7 100644 --- a/sqlx-core/src/value.rs +++ b/sqlx-core/src/value.rs @@ -3,6 +3,7 @@ use std::borrow::Cow; use crate::database::{Database, HasValueRef}; use crate::decode::Decode; use crate::error::{mismatched_types, Error}; +use crate::types::Type; /// An owned value from the database. pub trait Value { @@ -30,7 +31,7 @@ pub trait Value { #[inline] fn decode<'r, T>(&'r self) -> T where - T: Decode<'r, Self::Database>, + T: Decode<'r, Self::Database> + Type, { self.try_decode::().unwrap() } @@ -64,14 +65,12 @@ pub trait Value { #[inline] fn try_decode<'r, T>(&'r self) -> Result where - T: Decode<'r, Self::Database>, + T: Decode<'r, Self::Database> + Type, { if !self.is_null() { - if let Some(actual_ty) = self.type_info() { - if !T::accepts(&actual_ty) { - return Err(Error::Decode(mismatched_types::( - &actual_ty, - ))); + if let Some(ty) = self.type_info() { + if !T::compatible(&ty) { + return Err(Error::Decode(mismatched_types::(&ty))); } } } diff --git a/sqlx-macros/src/database/mod.rs b/sqlx-macros/src/database/mod.rs index c228700d..cbc0c5dd 100644 --- a/sqlx-macros/src/database/mod.rs +++ b/sqlx-macros/src/database/mod.rs @@ -53,7 +53,7 @@ macro_rules! impl_database_ext { )* $( $(#[$meta])? - _ if <$ty as sqlx_core::decode::Decode<$database>>::accepts(&info) => Some(input_ty!($ty $(, $input)?)), + _ if <$ty as sqlx_core::types::Type<$database>>::compatible(info) => Some(input_ty!($ty $(, $input)?)), )* _ => None } @@ -67,7 +67,7 @@ macro_rules! impl_database_ext { )* $( $(#[$meta])? - _ if <$ty as sqlx_core::decode::Decode<$database>>::accepts(&info) => return Some(stringify!($ty)), + _ if <$ty as sqlx_core::types::Type<$database>>::compatible(info) => return Some(stringify!($ty)), )* _ => None } diff --git a/sqlx-macros/src/derives/decode.rs b/sqlx-macros/src/derives/decode.rs index a3cd102c..e3c5cc78 100644 --- a/sqlx-macros/src/derives/decode.rs +++ b/sqlx-macros/src/derives/decode.rs @@ -71,10 +71,6 @@ fn expand_derive_decode_transparent( let tts = quote!( impl #impl_generics sqlx::decode::Decode<'r, DB> for #ident #ty_generics #where_clause { - fn accepts(ty: &DB::TypeInfo) -> bool { - <#ty as sqlx::decode::Decode<'r, DB>>::accepts(ty) - } - fn decode(value: >::ValueRef) -> std::result::Result> { <#ty as sqlx::decode::Decode<'r, DB>>::decode(value).map(Self) } @@ -104,10 +100,6 @@ fn expand_derive_decode_weak_enum( Ok(quote!( impl<'r, DB: sqlx::Database> sqlx::decode::Decode<'r, DB> for #ident where #repr: sqlx::decode::Decode<'r, DB> { - fn accepts(ty: &DB::TypeInfo) -> bool { - <#repr as sqlx::decode::Decode<'r, DB>>::accepts(ty) - } - fn decode(value: >::ValueRef) -> std::result::Result> { let value = <#repr as sqlx::decode::Decode<'r, DB>>::decode(value)?; @@ -159,10 +151,6 @@ fn expand_derive_decode_strong_enum( if cfg!(feature = "mysql") { tts.extend(quote!( impl<'r> sqlx::decode::Decode<'r, sqlx::mysql::MySql> for #ident { - fn accepts(ty: &sqlx::mysql::MySqlTypeInfo) -> bool { - ty == sqlx::mysql::MySqlTypeInfo::__enum() - } - fn decode(value: sqlx::mysql::MySqlValueRef<'r>) -> std::result::Result> { let value = <&'r str as sqlx::decode::Decode<'r, sqlx::mysql::MySql>>::decode(value)?; @@ -175,10 +163,6 @@ fn expand_derive_decode_strong_enum( if cfg!(feature = "postgres") { tts.extend(quote!( impl<'r> sqlx::decode::Decode<'r, sqlx::postgres::Postgres> for #ident { - fn accepts(ty: &sqlx::postgres::PgTypeInfo) -> bool { - *ty == <#ident as sqlx::Type>::type_info() - } - fn decode(value: sqlx::postgres::PgValueRef<'r>) -> std::result::Result> { let value = <&'r str as sqlx::decode::Decode<'r, sqlx::postgres::Postgres>>::decode(value)?; @@ -191,10 +175,6 @@ fn expand_derive_decode_strong_enum( if cfg!(feature = "sqlite") { tts.extend(quote!( impl<'r> sqlx::decode::Decode<'r, sqlx::sqlite::Sqlite> for #ident { - fn accepts(ty: &sqlx::sqlite::SqliteTypeInfo) -> bool { - <&str as sqlx::decode::Decode<'r, DB>>::accepts(ty) - } - fn decode(value: sqlx::sqlite::SqliteValueRef<'r>) -> std::result::Result> { let value = <&'r str as sqlx::decode::Decode<'r, sqlx::sqlite::Sqlite>>::decode(value)?; @@ -250,10 +230,6 @@ fn expand_derive_decode_struct( tts.extend(quote!( impl #impl_generics sqlx::decode::Decode<'r, sqlx::Postgres> for #ident #ty_generics #where_clause { - fn accepts(ty: &sqlx::postgres::PgTypeInfo) -> bool { - *ty == >::type_info() - } - fn decode(value: sqlx::postgres::PgValueRef<'r>) -> std::result::Result> { let mut decoder = sqlx::postgres::types::PgRecordDecoder::new(value)?; diff --git a/sqlx-macros/src/derives/type.rs b/sqlx-macros/src/derives/type.rs index fa8a9a30..ad41dba5 100644 --- a/sqlx-macros/src/derives/type.rs +++ b/sqlx-macros/src/derives/type.rs @@ -129,6 +129,10 @@ fn expand_derive_has_sql_type_strong_enum( fn type_info() -> sqlx::mysql::MySqlTypeInfo { sqlx::mysql::MySqlTypeInfo::__enum() } + + fn compatible(ty: &sqlx::mysql::MySqlTypeInfo) -> bool { + ty == sqlx::mysql::MySqlTypeInfo::__enum() + } } )); } @@ -151,6 +155,10 @@ fn expand_derive_has_sql_type_strong_enum( fn type_info() -> sqlx::sqlite::SqliteTypeInfo { >::type_info() } + + fn compatible(ty: &sqlx::sqlite::SqliteTypeInfo) -> bool { + <&str as sqlx::types::Type>::compatible(ty) + } } )); } diff --git a/src/lib.rs b/src/lib.rs index b22005a6..14e63f01 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -77,7 +77,7 @@ pub mod encode { /// Provides [`Decode`](decode/trait.Decode.html) for decoding values from the database. pub mod decode { - pub use sqlx_core::decode::{Decode, Result}; + pub use sqlx_core::decode::Decode; #[cfg(feature = "macros")] #[doc(hidden)]