diff --git a/sqlx-core/src/database.rs b/sqlx-core/src/database.rs index b2033bd3..559d5d1e 100644 --- a/sqlx-core/src/database.rs +++ b/sqlx-core/src/database.rs @@ -7,7 +7,8 @@ use crate::{Column, Connection, QueryResult, Row, Runtime}; /// This trait encapsulates a complete set of traits that implement a driver for a /// specific database (e.g., MySQL, PostgreSQL). /// -pub trait Database: 'static + Sized + Debug + for<'x> HasOutput<'x> +pub trait Database: + 'static + Sized + Debug + for<'x> HasOutput<'x> + for<'r> HasRawValue<'r> where Rt: Runtime, { @@ -22,12 +23,24 @@ where /// The concrete [`QueryResult`] implementation for this database. type QueryResult: QueryResult; + + /// The concrete [`TypeInfo`] implementation for this database. + type TypeInfo; + + /// The concrete [`TypeId`] implementation for this database. + type TypeId; } -/// Associates [`Database`] with a `Output` of a generic lifetime. -/// 'x: single execution +/// Associates [`Database`] with an `Output` of a generic lifetime. +// 'x: single execution pub trait HasOutput<'x> { - /// The concrete type to hold the output for `Encode` for this database. This may be - /// a simple alias to `&'x mut Vec`. + /// The concrete type to hold the output for [`Encode`] for this database. type Output; } + +/// Associates [`Database`] with a `RawValue` of a generic lifetime. +// 'r: row +pub trait HasRawValue<'r> { + /// The concrete type to hold the input for [`Decode`] for this database. + type RawValue; +} diff --git a/sqlx-core/src/decode.rs b/sqlx-core/src/decode.rs index 068568bc..0f934ae0 100644 --- a/sqlx-core/src/decode.rs +++ b/sqlx-core/src/decode.rs @@ -1,5 +1,63 @@ +use std::error::Error as StdError; +use std::fmt::{self, Display, Formatter}; +use std::str::Utf8Error; + +use crate::database::HasRawValue; use crate::{Database, Runtime}; -pub trait Decode>: Sized { - fn decode(raw: &[u8]) -> crate::Result; +/// A type that can be decoded from a SQL value. +pub trait Decode<'r, Db: Database, Rt: Runtime>: Sized + Send + Sync { + fn decode(value: >::RawValue) -> Result; } + +/// A type that can be decoded from a SQL value, without borrowing any data +/// from the row. +pub trait DecodeOwned, Rt: Runtime>: for<'de> Decode<'de, Db, Rt> {} + +impl, Rt: Runtime> DecodeOwned for T where + T: for<'de> Decode<'de, Db, Rt> +{ +} + +/// Errors which can occur while decoding a SQL value. +#[derive(Debug)] +#[non_exhaustive] +pub enum Error { + /// An unexpected SQL `NULL` was encountered during decoding. + /// + /// To decode potentially `NULL` values, wrap the target type in `Option`. + /// + UnexpectedNull, + + /// Attempted to decode non-UTF-8 data into a Rust `str`. + NotUtf8(Utf8Error), + + /// A general error raised while decoding a value. + Custom(Box), +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::UnexpectedNull => f.write_str("unexpected null; try decoding as an `Option`"), + + Self::NotUtf8(error) => { + write!(f, "{}", error) + } + + Self::Custom(error) => { + write!(f, "{}", error) + } + } + } +} + +// noinspection DuplicatedCode +impl From for Error { + fn from(error: E) -> Self { + Self::Custom(Box::new(error)) + } +} + +/// A specialized result type representing the result of decoding a SQL value. +pub type Result = std::result::Result; diff --git a/sqlx-core/src/encode.rs b/sqlx-core/src/encode.rs new file mode 100644 index 00000000..274b4eae --- /dev/null +++ b/sqlx-core/src/encode.rs @@ -0,0 +1,52 @@ +use std::error::Error as StdError; +use std::fmt::{self, Display, Formatter}; + +use crate::database::{HasOutput, HasRawValue}; +use crate::{Database, Runtime}; + +/// A type that can be encoded into a SQL value. +pub trait Encode, Rt: Runtime>: Send + Sync { + /// Encode this value into a SQL value. + fn encode(&self, ty: &Db::TypeInfo, out: &mut >::Output) -> Result<()>; + + #[doc(hidden)] + #[inline] + fn __type_name(&self) -> &'static str { + std::any::type_name::() + } +} + +impl, Db: Database, Rt: Runtime> Encode for &T { + #[inline] + fn encode(&self, ty: &Db::TypeInfo, out: &mut >::Output) -> Result<()> { + (*self).encode(ty, out) + } +} + +/// Errors which can occur while encoding a SQL value. +#[derive(Debug)] +#[non_exhaustive] +pub enum Error { + /// A general error raised while encoding a value. + Custom(Box), +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::Custom(error) => { + write!(f, "{}", error) + } + } + } +} + +// noinspection DuplicatedCode +impl From for Error { + fn from(error: E) -> Self { + Self::Custom(Box::new(error)) + } +} + +/// A specialized result type representing the result of encoding a SQL value. +pub type Result = std::result::Result; diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 249161dc..a399bf76 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -2,6 +2,9 @@ use std::borrow::Cow; use std::error::Error as StdError; use std::fmt::{self, Display, Formatter}; +use crate::decode::Error as DecodeError; +use crate::encode::Error as EncodeError; + mod database; pub use database::DatabaseError; @@ -26,6 +29,10 @@ pub enum Error { /// Use `fetch_optional` to return `None` instead of signaling an error. /// RowNotFound, + + Decode(DecodeError), + + Encode(EncodeError), } impl Error { @@ -69,6 +76,14 @@ impl Display for Error { Self::RowNotFound => { f.write_str("no row returned by a query required to return at least one row") } + + Self::Decode(error) => { + write!(f, "{}", error) + } + + Self::Encode(error) => { + write!(f, "{}", error) + } } } } @@ -96,3 +111,15 @@ impl From for Error { Self::Network(error.into()) } } + +impl From for Error { + fn from(error: DecodeError) -> Self { + Self::Decode(error) + } +} + +impl From for Error { + fn from(error: EncodeError) -> Self { + Self::Encode(error) + } +} diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index 586f471d..58733c00 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -20,18 +20,19 @@ mod acquire; mod close; +mod column; mod connect; mod connection; -mod database; +pub mod database; +pub mod decode; +pub mod encode; mod error; mod executor; mod options; mod pool; -mod runtime; -mod decode; -mod row; mod query_result; -mod column; +mod row; +mod runtime; #[doc(hidden)] pub mod io; @@ -50,16 +51,18 @@ pub use acquire::Acquire; #[cfg(feature = "blocking")] pub use blocking::runtime::Blocking; pub use close::Close; -pub use connect::Connect; pub use column::Column; +pub use connect::Connect; pub use connection::Connection; -pub use database::{Database, HasOutput}; +pub use database::Database; +pub use decode::Decode; +pub use encode::Encode; pub use error::{DatabaseError, Error, Result}; pub use executor::Executor; -pub use query_result::QueryResult; -pub use row::Row; pub use options::ConnectOptions; pub use pool::Pool; +pub use query_result::QueryResult; +pub use row::Row; #[cfg(feature = "actix")] pub use runtime::Actix; #[cfg(feature = "async")] diff --git a/sqlx-mysql/src/database.rs b/sqlx-mysql/src/database.rs index 85c7cd3e..8b6f3ad8 100644 --- a/sqlx-mysql/src/database.rs +++ b/sqlx-mysql/src/database.rs @@ -1,6 +1,10 @@ -use sqlx_core::{Database, HasOutput, Runtime}; +use sqlx_core::database::{HasOutput, HasRawValue}; +use sqlx_core::{Database, Runtime}; -use super::{MySqlConnection, MySqlRow, MySqlColumn, MySqlQueryResult}; +use super::{ + MySqlColumn, MySqlConnection, MySqlOutput, MySqlQueryResult, MySqlRawValue, MySqlRow, + MySqlTypeId, MySqlTypeInfo, +}; #[derive(Debug)] pub struct MySql; @@ -8,13 +12,21 @@ pub struct MySql; impl Database for MySql { type Connection = MySqlConnection; - type Row = MySqlRow; - type Column = MySqlColumn; + type Row = MySqlRow; + type QueryResult = MySqlQueryResult; + + type TypeId = MySqlTypeId; + + type TypeInfo = MySqlTypeInfo; } impl<'x> HasOutput<'x> for MySql { - type Output = &'x mut Vec; + type Output = MySqlOutput<'x>; +} + +impl<'r> HasRawValue<'r> for MySql { + type RawValue = MySqlRawValue<'r>; }