diff --git a/sqlx-core/src/mysql/database.rs b/sqlx-core/src/mysql/database.rs index 5de50304..0a5d4e19 100644 --- a/sqlx-core/src/mysql/database.rs +++ b/sqlx-core/src/mysql/database.rs @@ -1,5 +1,8 @@ -use crate::database::{Database, HasCursor, HasRawValue, HasRow}; +use crate::cursor::HasCursor; +use crate::database::Database; use crate::mysql::error::MySqlError; +use crate::row::HasRow; +use crate::value::HasRawValue; /// **MySQL** database driver. #[derive(Debug)] @@ -32,5 +35,7 @@ impl<'c, 'q> HasCursor<'c, 'q> for MySql { } impl<'c> HasRawValue<'c> for MySql { + type Database = MySql; + type RawValue = Option>; } diff --git a/sqlx-core/src/postgres/connection.rs b/sqlx-core/src/postgres/connection.rs index 480ad895..87cfe61d 100644 --- a/sqlx-core/src/postgres/connection.rs +++ b/sqlx-core/src/postgres/connection.rs @@ -11,8 +11,9 @@ use crate::executor::Executor; use crate::postgres::database::Postgres; use crate::postgres::protocol::{ Authentication, AuthenticationMd5, AuthenticationSasl, BackendKeyData, Message, - PasswordMessage, StartupMessage, StatementId, Terminate, TypeFormat, + PasswordMessage, StartupMessage, StatementId, Terminate, }; +use crate::postgres::row::Statement; use crate::postgres::stream::PgStream; use crate::postgres::{sasl, tls}; use crate::url::Url; @@ -82,9 +83,11 @@ pub struct PgConnection { pub(super) next_statement_id: u32, pub(super) is_ready: bool, - pub(super) cache_statement: HashMap, StatementId>, - pub(super) cache_statement_columns: HashMap, usize>>>, - pub(super) cache_statement_formats: HashMap>, + // cache query -> statement ID + pub(super) cache_statement_id: HashMap, StatementId>, + + // cache statement ID -> statement description + pub(super) cache_statement: HashMap>, // Work buffer for the value ranges of the current row // This is used as the backing memory for each Row's value indexes @@ -250,9 +253,8 @@ impl PgConnection { current_row_values: Vec::with_capacity(10), next_statement_id: 1, is_ready: true, - cache_statement: HashMap::new(), - cache_statement_columns: HashMap::new(), - cache_statement_formats: HashMap::new(), + cache_statement_id: HashMap::with_capacity(10), + cache_statement: HashMap::with_capacity(10), process_id: key_data.process_id, secret_key: key_data.secret_key, }) diff --git a/sqlx-core/src/postgres/cursor.rs b/sqlx-core/src/postgres/cursor.rs index 7cf01d4b..cf2f3bd4 100644 --- a/sqlx-core/src/postgres/cursor.rs +++ b/sqlx-core/src/postgres/cursor.rs @@ -7,16 +7,14 @@ use crate::connection::ConnectionSource; use crate::cursor::Cursor; use crate::executor::Execute; use crate::pool::Pool; -use crate::postgres::protocol::{ - DataRow, Message, ReadyForQuery, RowDescription, StatementId, TypeFormat, -}; +use crate::postgres::protocol::{DataRow, Message, ReadyForQuery, RowDescription, StatementId}; +use crate::postgres::row::{Column, Statement}; use crate::postgres::{PgArguments, PgConnection, PgRow, Postgres}; pub struct PgCursor<'c, 'q> { source: ConnectionSource<'c, PgConnection>, query: Option<(&'q str, Option)>, - columns: Arc, usize>>, - formats: Arc<[TypeFormat]>, + statement: Arc, } impl crate::cursor::private::Sealed for PgCursor<'_, '_> {} @@ -32,8 +30,7 @@ impl<'c, 'q> Cursor<'c, 'q> for PgCursor<'c, 'q> { { Self { source: ConnectionSource::Pool(pool.clone()), - columns: Arc::default(), - formats: Arc::new([] as [TypeFormat; 0]), + statement: Arc::default(), query: Some(query.into_parts()), } } @@ -46,8 +43,7 @@ impl<'c, 'q> Cursor<'c, 'q> for PgCursor<'c, 'q> { { Self { source: ConnectionSource::ConnectionRef(conn), - columns: Arc::default(), - formats: Arc::new([] as [TypeFormat; 0]), + statement: Arc::default(), query: Some(query.into_parts()), } } @@ -57,29 +53,33 @@ impl<'c, 'q> Cursor<'c, 'q> for PgCursor<'c, 'q> { } } -fn parse_row_description(rd: RowDescription) -> (HashMap, usize>, Vec) { - let mut columns = HashMap::new(); - let mut formats = Vec::new(); +fn parse_row_description(rd: RowDescription) -> Statement { + let mut names = HashMap::new(); + let mut columns = Vec::new(); columns.reserve(rd.fields.len()); - formats.reserve(rd.fields.len()); + names.reserve(rd.fields.len()); for (index, field) in rd.fields.iter().enumerate() { if let Some(name) = &field.name { - columns.insert(name.clone(), index); + names.insert(name.clone(), index); } - formats.push(field.type_format); + columns.push(Column { + type_oid: field.type_id.0, + format: field.type_format, + }); } - (columns, formats) + Statement { + columns: columns.into_boxed_slice(), + names, + } } // Used to describe the incoming results // We store the column map in an Arc and share it among all rows -async fn expect_desc( - conn: &mut PgConnection, -) -> crate::Result, usize>, Vec)> { +async fn expect_desc(conn: &mut PgConnection) -> crate::Result { let description: Option<_> = loop { match conn.stream.receive().await? { Message::ParseComplete | Message::BindComplete => {} @@ -106,24 +106,15 @@ async fn expect_desc( // A form of describe that uses the statement cache async fn get_or_describe( conn: &mut PgConnection, - statement: StatementId, -) -> crate::Result, usize>>, Arc<[TypeFormat]>)> { - if !conn.cache_statement_columns.contains_key(&statement) - || !conn.cache_statement_formats.contains_key(&statement) - { - let (columns, formats) = expect_desc(conn).await?; + id: StatementId, +) -> crate::Result> { + if !conn.cache_statement.contains_key(&id) { + let statement = expect_desc(conn).await?; - conn.cache_statement_columns - .insert(statement, Arc::new(columns)); - - conn.cache_statement_formats - .insert(statement, Arc::from(formats)); + conn.cache_statement.insert(id, Arc::new(statement)); } - Ok(( - Arc::clone(&conn.cache_statement_columns[&statement]), - Arc::clone(&conn.cache_statement_formats[&statement]), - )) + Ok(Arc::clone(&conn.cache_statement[&id])) } async fn next<'a, 'c: 'a, 'q: 'a>( @@ -141,10 +132,7 @@ async fn next<'a, 'c: 'a, 'q: 'a>( if let Some(statement) = statement { // A prepared statement will re-use the previous column map if // this query has been executed before - let (columns, formats) = get_or_describe(&mut *conn, statement).await?; - - cursor.columns = columns; - cursor.formats = formats; + cursor.statement = get_or_describe(&mut *conn, statement).await?; } // A non-prepared query must be described each time @@ -171,18 +159,14 @@ async fn next<'a, 'c: 'a, 'q: 'a>( Message::RowDescription => { let rd = RowDescription::read(conn.stream.buffer())?; - let (columns, formats) = parse_row_description(rd); - - cursor.columns = Arc::new(columns); - cursor.formats = Arc::from(formats); + cursor.statement = Arc::new(parse_row_description(rd)); } Message::DataRow => { let data = DataRow::read(conn.stream.buffer(), &mut conn.current_row_values)?; return Ok(Some(PgRow { - columns: Arc::clone(&cursor.columns), - formats: Arc::clone(&cursor.formats), + statement: Arc::clone(&cursor.statement), data, })); } diff --git a/sqlx-core/src/postgres/database.rs b/sqlx-core/src/postgres/database.rs index 56a34a2b..5deabcba 100644 --- a/sqlx-core/src/postgres/database.rs +++ b/sqlx-core/src/postgres/database.rs @@ -1,19 +1,21 @@ //! Types which represent various database drivers. -use crate::database::{Database, HasCursor, HasRawValue, HasRow}; -use crate::postgres::error::PgError; -use crate::postgres::row::PgValue; +use crate::cursor::HasCursor; +use crate::database::Database; +use crate::postgres::{PgArguments, PgConnection, PgCursor, PgError, PgRow, PgTypeInfo, PgValue}; +use crate::row::HasRow; +use crate::value::HasRawValue; /// **Postgres** database driver. #[derive(Debug)] pub struct Postgres; impl Database for Postgres { - type Connection = super::PgConnection; + type Connection = PgConnection; - type Arguments = super::PgArguments; + type Arguments = PgArguments; - type TypeInfo = super::PgTypeInfo; + type TypeInfo = PgTypeInfo; type TableId = u32; @@ -25,15 +27,17 @@ impl Database for Postgres { impl<'a> HasRow<'a> for Postgres { type Database = Postgres; - type Row = super::PgRow<'a>; + type Row = PgRow<'a>; } impl<'s, 'q> HasCursor<'s, 'q> for Postgres { type Database = Postgres; - type Cursor = super::PgCursor<'s, 'q>; + type Cursor = PgCursor<'s, 'q>; } impl<'a> HasRawValue<'a> for Postgres { - type RawValue = Option>; + type Database = Postgres; + + type RawValue = PgValue<'a>; } diff --git a/sqlx-core/src/postgres/executor.rs b/sqlx-core/src/postgres/executor.rs index f165f8bb..1bcb3bba 100644 --- a/sqlx-core/src/postgres/executor.rs +++ b/sqlx-core/src/postgres/executor.rs @@ -22,7 +22,7 @@ impl PgConnection { } pub(crate) fn write_prepare(&mut self, query: &str, args: &PgArguments) -> StatementId { - if let Some(&id) = self.cache_statement.get(query) { + if let Some(&id) = self.cache_statement_id.get(query) { id } else { let id = StatementId(self.next_statement_id); @@ -35,7 +35,7 @@ impl PgConnection { param_types: &*args.types, }); - self.cache_statement.insert(query.into(), id); + self.cache_statement_id.insert(query.into(), id); id } @@ -106,7 +106,7 @@ impl PgConnection { // Next, [Describe] will return the expected result columns and types // Conditionally run [Describe] only if the results have not been cached - if !self.cache_statement_columns.contains_key(&statement) { + if !self.cache_statement.contains_key(&statement) { self.write_describe(protocol::Describe::Portal("")); } diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs index a98b4f02..dab25ab8 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -6,8 +6,9 @@ pub use cursor::PgCursor; pub use database::Postgres; pub use error::PgError; pub use listen::{PgListener, PgNotification}; -pub use row::{PgRow, PgValue}; +pub use row::PgRow; pub use types::PgTypeInfo; +pub use value::{PgData, PgValue}; mod arguments; mod connection; @@ -22,6 +23,7 @@ mod sasl; mod stream; mod tls; pub mod types; +mod value; /// An alias for [`Pool`][crate::pool::Pool], specialized for **Postgres**. #[cfg_attr(docsrs, doc(cfg(feature = "postgres")))] @@ -29,5 +31,4 @@ pub type PgPool = crate::pool::Pool; make_query_as!(PgQueryAs, Postgres, PgRow); impl_map_row_for_row!(Postgres, PgRow); -impl_column_index_for_row!(PgRow); impl_from_row_for_tuples!(Postgres, PgRow); diff --git a/sqlx-core/src/postgres/protocol/data_row.rs b/sqlx-core/src/postgres/protocol/data_row.rs index ba4bb36a..37d2548a 100644 --- a/sqlx-core/src/postgres/protocol/data_row.rs +++ b/sqlx-core/src/postgres/protocol/data_row.rs @@ -26,9 +26,6 @@ impl<'c> DataRow<'c> { buffer: &'c [u8], values: &'c mut Vec>>, ) -> crate::Result { - // let buffer = connection.stream.buffer(); - // let values = &mut connection.current_row_values; - values.clear(); let mut buf = buffer; diff --git a/sqlx-core/src/postgres/row.rs b/sqlx-core/src/postgres/row.rs index 25261b72..55509efc 100644 --- a/sqlx-core/src/postgres/row.rs +++ b/sqlx-core/src/postgres/row.rs @@ -1,38 +1,37 @@ -use core::str::{from_utf8, Utf8Error}; - use std::collections::HashMap; -use std::convert::TryFrom; use std::sync::Arc; -use crate::error::UnexpectedNullError; use crate::postgres::protocol::{DataRow, TypeFormat}; +use crate::postgres::value::PgValue; use crate::postgres::Postgres; use crate::row::{ColumnIndex, Row}; -/// A value from Postgres. This may be in a BINARY or TEXT format depending -/// on the data type and if the query was prepared or not. -#[derive(Debug)] -pub enum PgValue<'c> { - Binary(&'c [u8]), - Text(&'c str), +// A statement has 0 or more columns being returned from the database +// For Postgres, each column has an OID and a format (binary or text) +// For simple (unprepared) queries, format will always be text +// For prepared queries, format will _almost_ always be binary +pub(crate) struct Column { + pub(crate) type_oid: u32, + pub(crate) format: TypeFormat, } -impl<'c> TryFrom>> for PgValue<'c> { - type Error = crate::Error; +// A statement description containing the column information used to +// properly decode data +#[derive(Default)] +pub(crate) struct Statement { + // column name -> position + pub(crate) names: HashMap, usize>, - #[inline] - fn try_from(value: Option>) -> Result { - match value { - Some(value) => Ok(value), - None => Err(crate::Error::decode(UnexpectedNullError)), - } - } + // all columns + pub(crate) columns: Box<[Column]>, } pub struct PgRow<'c> { pub(super) data: DataRow<'c>, - pub(super) columns: Arc, usize>>, - pub(super) formats: Arc<[TypeFormat]>, + + // shared reference to the statement this row is coming from + // allows us to get the column information on demand + pub(super) statement: Arc, } impl crate::row::private_row::Sealed for PgRow<'_> {} @@ -40,24 +39,47 @@ impl crate::row::private_row::Sealed for PgRow<'_> {} impl<'c> Row<'c> for PgRow<'c> { type Database = Postgres; + #[inline] fn len(&self) -> usize { self.data.len() } #[doc(hidden)] - fn try_get_raw(&self, index: I) -> crate::Result>> + fn try_get_raw(&self, index: I) -> crate::Result> where I: ColumnIndex<'c, Self>, { let index = index.index(self)?; + let column = &self.statement.columns[index]; let buffer = self.data.get(index); + let value = match (column.format, buffer) { + (_, None) => PgValue::null(column.type_oid), + (TypeFormat::Binary, Some(buf)) => PgValue::bytes(column.type_oid, buf), + (TypeFormat::Text, Some(buf)) => PgValue::utf8(column.type_oid, buf)?, + }; - buffer - .map(|buf| match self.formats[index] { - TypeFormat::Binary => Ok(PgValue::Binary(buf)), - TypeFormat::Text => Ok(PgValue::Text(from_utf8(buf)?)), - }) - .transpose() - .map_err(|err: Utf8Error| crate::Error::Decode(Box::new(err))) + Ok(value) + } +} + +impl<'c> ColumnIndex<'c, PgRow<'c>> for usize { + fn index(&self, row: &PgRow<'c>) -> crate::Result { + let len = Row::len(row); + + if *self >= len { + return Err(crate::Error::ColumnIndexOutOfBounds { len, index: *self }); + } + + Ok(*self) + } +} + +impl<'c> ColumnIndex<'c, PgRow<'c>> for str { + fn index(&self, row: &PgRow<'c>) -> crate::Result { + row.statement + .names + .get(self) + .ok_or_else(|| crate::Error::ColumnNotFound((*self).into())) + .map(|&index| index as usize) } } diff --git a/sqlx-core/src/postgres/types/array.rs b/sqlx-core/src/postgres/types/array.rs index f7203553..845a5274 100644 --- a/sqlx-core/src/postgres/types/array.rs +++ b/sqlx-core/src/postgres/types/array.rs @@ -41,7 +41,7 @@ where [T]: Type, T: Type, { - fn decode(value: Option>) -> crate::Result { + fn decode(value: PgValue<'de>) -> crate::Result { PgArrayDecoder::::new(value)?.collect() } } diff --git a/sqlx-core/src/postgres/types/bigdecimal.rs b/sqlx-core/src/postgres/types/bigdecimal.rs index 3061deb8..794655ed 100644 --- a/sqlx-core/src/postgres/types/bigdecimal.rs +++ b/sqlx-core/src/postgres/types/bigdecimal.rs @@ -6,7 +6,7 @@ use num_bigint::{BigInt, Sign}; use crate::decode::Decode; use crate::encode::Encode; -use crate::postgres::{PgTypeInfo, PgValue, Postgres}; +use crate::postgres::{PgData, PgTypeInfo, PgValue, Postgres}; use crate::types::Type; use super::raw::{PgNumeric, PgNumericSign}; @@ -152,10 +152,10 @@ impl Encode for BigDecimal { } impl Decode<'_, Postgres> for BigDecimal { - fn decode(value: Option) -> crate::Result { - match value.try_into()? { - PgValue::Binary(binary) => PgNumeric::from_bytes(binary)?.try_into(), - PgValue::Text(text) => text + fn decode(value: PgValue) -> crate::Result { + match value.try_get()? { + PgData::Binary(binary) => PgNumeric::from_bytes(binary)?.try_into(), + PgData::Text(text) => text .parse::() .map_err(|e| crate::Error::Decode(e.into())), } diff --git a/sqlx-core/src/postgres/types/bool.rs b/sqlx-core/src/postgres/types/bool.rs index f4336ccb..2c0f2433 100644 --- a/sqlx-core/src/postgres/types/bool.rs +++ b/sqlx-core/src/postgres/types/bool.rs @@ -1,11 +1,7 @@ -use std::convert::TryInto; - use crate::decode::Decode; use crate::encode::Encode; use crate::postgres::protocol::TypeId; -use crate::postgres::row::PgValue; -use crate::postgres::types::PgTypeInfo; -use crate::postgres::Postgres; +use crate::postgres::{PgData, PgTypeInfo, PgValue, Postgres}; use crate::types::Type; impl Type for bool { @@ -32,18 +28,14 @@ impl Encode for bool { } impl<'de> Decode<'de, Postgres> for bool { - fn decode(value: Option>) -> crate::Result { - match value.try_into()? { - PgValue::Binary(buf) => Ok(buf.get(0).map(|&b| b != 0).unwrap_or_default()), + fn decode(value: PgValue<'de>) -> crate::Result { + match value.try_get()? { + PgData::Binary(buf) => Ok(buf.get(0).map(|&b| b != 0).unwrap_or_default()), - PgValue::Text("t") => Ok(true), - PgValue::Text("f") => Ok(false), + PgData::Text("t") => Ok(true), + PgData::Text("f") => Ok(false), - PgValue::Text(s) => { - return Err(crate::Error::Decode( - format!("unexpected value {:?} for boolean", s).into(), - )); - } + PgData::Text(s) => Err(decode_err!("unexpected value {:?} for boolean", s)), } } } diff --git a/sqlx-core/src/postgres/types/bytes.rs b/sqlx-core/src/postgres/types/bytes.rs index 33f3e3cc..d85bbe76 100644 --- a/sqlx-core/src/postgres/types/bytes.rs +++ b/sqlx-core/src/postgres/types/bytes.rs @@ -1,10 +1,8 @@ -use std::convert::TryInto; - use crate::decode::Decode; use crate::encode::Encode; use crate::postgres::protocol::TypeId; use crate::postgres::types::PgTypeInfo; -use crate::postgres::{PgValue, Postgres}; +use crate::postgres::{PgData, PgValue, Postgres}; use crate::types::Type; impl Type for [u8] { @@ -44,10 +42,10 @@ impl Encode for Vec { } impl<'de> Decode<'de, Postgres> for Vec { - fn decode(value: Option>) -> crate::Result { - match value.try_into()? { - PgValue::Binary(buf) => Ok(buf.to_vec()), - PgValue::Text(s) => { + fn decode(value: PgValue<'de>) -> crate::Result { + match value.try_get()? { + PgData::Binary(buf) => Ok(buf.to_vec()), + PgData::Text(s) => { // BYTEA is formatted as \x followed by hex characters hex::decode(&s[2..]).map_err(crate::Error::decode) } @@ -56,10 +54,10 @@ impl<'de> Decode<'de, Postgres> for Vec { } impl<'de> Decode<'de, Postgres> for &'de [u8] { - fn decode(value: Option>) -> crate::Result { - match value.try_into()? { - PgValue::Binary(buf) => Ok(buf), - PgValue::Text(_s) => Err(crate::Error::Decode( + fn decode(value: PgValue<'de>) -> crate::Result { + match value.try_get()? { + PgData::Binary(buf) => Ok(buf), + PgData::Text(_s) => Err(crate::Error::Decode( "unsupported decode to `&[u8]` of BYTEA in a simple query; \ use a prepared query or decode to `Vec`" .into(), diff --git a/sqlx-core/src/postgres/types/chrono.rs b/sqlx-core/src/postgres/types/chrono.rs index c1e13944..11d746de 100644 --- a/sqlx-core/src/postgres/types/chrono.rs +++ b/sqlx-core/src/postgres/types/chrono.rs @@ -7,9 +7,8 @@ use chrono::{DateTime, Duration, Local, NaiveDate, NaiveDateTime, NaiveTime, Tim use crate::decode::Decode; use crate::encode::Encode; use crate::postgres::protocol::TypeId; -use crate::postgres::row::PgValue; use crate::postgres::types::PgTypeInfo; -use crate::postgres::Postgres; +use crate::postgres::{PgData, PgValue, Postgres}; use crate::types::Type; use crate::Error; @@ -95,15 +94,15 @@ where } impl<'de> Decode<'de, Postgres> for NaiveTime { - fn decode(value: Option>) -> crate::Result { - match value.try_into()? { - PgValue::Binary(mut buf) => { + fn decode(value: PgValue<'de>) -> crate::Result { + match value.try_get()? { + PgData::Binary(mut buf) => { let micros = buf.read_i64::().map_err(Error::decode)?; Ok(NaiveTime::from_hms(0, 0, 0) + Duration::microseconds(micros)) } - PgValue::Text(s) => NaiveTime::parse_from_str(s, "%H:%M:%S%.f").map_err(Error::decode), + PgData::Text(s) => NaiveTime::parse_from_str(s, "%H:%M:%S%.f").map_err(Error::decode), } } } @@ -123,15 +122,15 @@ impl Encode for NaiveTime { } impl<'de> Decode<'de, Postgres> for NaiveDate { - fn decode(value: Option>) -> crate::Result { - match value.try_into()? { - PgValue::Binary(mut buf) => { + fn decode(value: PgValue<'de>) -> crate::Result { + match value.try_get()? { + PgData::Binary(mut buf) => { let days: i32 = buf.read_i32::().map_err(Error::decode)?; Ok(NaiveDate::from_ymd(2000, 1, 1) + Duration::days(days as i64)) } - PgValue::Text(s) => NaiveDate::parse_from_str(s, "%Y-%m-%d").map_err(Error::decode), + PgData::Text(s) => NaiveDate::parse_from_str(s, "%Y-%m-%d").map_err(Error::decode), } } } @@ -154,9 +153,9 @@ impl Encode for NaiveDate { } impl<'de> Decode<'de, Postgres> for NaiveDateTime { - fn decode(value: Option>) -> crate::Result { - match value.try_into()? { - PgValue::Binary(mut buf) => { + fn decode(value: PgValue<'de>) -> crate::Result { + match value.try_get()? { + PgData::Binary(mut buf) => { let micros = buf.read_i64::().map_err(Error::decode)?; postgres_epoch() @@ -173,7 +172,7 @@ impl<'de> Decode<'de, Postgres> for NaiveDateTime { }) } - PgValue::Text(s) => { + PgData::Text(s) => { NaiveDateTime::parse_from_str( s, if s.contains('+') { @@ -207,14 +206,14 @@ impl Encode for NaiveDateTime { } impl<'de> Decode<'de, Postgres> for DateTime { - fn decode(value: Option>) -> crate::Result { + fn decode(value: PgValue<'de>) -> crate::Result { let date_time = Decode::::decode(value)?; Ok(DateTime::from_utc(date_time, Utc)) } } impl<'de> Decode<'de, Postgres> for DateTime { - fn decode(value: Option>) -> crate::Result { + fn decode(value: PgValue<'de>) -> crate::Result { let date_time = Decode::::decode(value)?; Ok(Local.from_utc_datetime(&date_time)) } @@ -265,18 +264,18 @@ fn test_encode_time() { #[test] fn test_decode_time() { let buf = [0u8; 8]; - let time: NaiveTime = Decode::::decode(Some(PgValue::Binary(&buf))).unwrap(); + let time: NaiveTime = Decode::::decode(Some(PgData::Binary(&buf))).unwrap(); assert_eq!(time, NaiveTime::from_hms(0, 0, 0),); // half an hour let buf = (1_000_000i64 * 60 * 30).to_be_bytes(); - let time: NaiveTime = Decode::::decode(Some(PgValue::Binary(&buf))).unwrap(); + let time: NaiveTime = Decode::::decode(Some(PgData::Binary(&buf))).unwrap(); assert_eq!(time, NaiveTime::from_hms(0, 30, 0),); // 12:53:05.125305 let buf = (1_000_000i64 * 60 * 60 * 12 + 1_000_000i64 * 60 * 53 + 1_000_000i64 * 5 + 125305) .to_be_bytes(); - let time: NaiveTime = Decode::::decode(Some(PgValue::Binary(&buf))).unwrap(); + let time: NaiveTime = Decode::::decode(Some(PgData::Binary(&buf))).unwrap(); assert_eq!(time, NaiveTime::from_hms_micro(12, 53, 5, 125305),); } @@ -308,15 +307,15 @@ fn test_encode_datetime() { #[test] fn test_decode_datetime() { let buf = [0u8; 8]; - let date: NaiveDateTime = Decode::::decode(Some(PgValue::Binary(&buf))).unwrap(); + let date: NaiveDateTime = Decode::::decode(Some(PgData::Binary(&buf))).unwrap(); assert_eq!(date.to_string(), "2000-01-01 00:00:00"); let buf = 3_600_000_000i64.to_be_bytes(); - let date: NaiveDateTime = Decode::::decode(Some(PgValue::Binary(&buf))).unwrap(); + let date: NaiveDateTime = Decode::::decode(Some(PgData::Binary(&buf))).unwrap(); assert_eq!(date.to_string(), "2000-01-01 01:00:00"); let buf = 629_377_265_000_000i64.to_be_bytes(); - let date: NaiveDateTime = Decode::::decode(Some(PgValue::Binary(&buf))).unwrap(); + let date: NaiveDateTime = Decode::::decode(Some(PgData::Binary(&buf))).unwrap(); assert_eq!(date.to_string(), "2019-12-11 11:01:05"); } @@ -344,14 +343,14 @@ fn test_encode_date() { #[test] fn test_decode_date() { let buf = [0; 4]; - let date: NaiveDate = Decode::::decode(Some(PgValue::Binary(&buf))).unwrap(); + let date: NaiveDate = Decode::::decode(Some(PgData::Binary(&buf))).unwrap(); assert_eq!(date.to_string(), "2000-01-01"); let buf = 366i32.to_be_bytes(); - let date: NaiveDate = Decode::::decode(Some(PgValue::Binary(&buf))).unwrap(); + let date: NaiveDate = Decode::::decode(Some(PgData::Binary(&buf))).unwrap(); assert_eq!(date.to_string(), "2001-01-01"); let buf = 7284i32.to_be_bytes(); - let date: NaiveDate = Decode::::decode(Some(PgValue::Binary(&buf))).unwrap(); + let date: NaiveDate = Decode::::decode(Some(PgData::Binary(&buf))).unwrap(); assert_eq!(date.to_string(), "2019-12-11"); } diff --git a/sqlx-core/src/postgres/types/float.rs b/sqlx-core/src/postgres/types/float.rs index 4183e32f..35195e78 100644 --- a/sqlx-core/src/postgres/types/float.rs +++ b/sqlx-core/src/postgres/types/float.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::str::FromStr; use byteorder::{NetworkEndian, ReadBytesExt}; @@ -8,7 +7,7 @@ use crate::encode::Encode; use crate::error::Error; use crate::postgres::protocol::TypeId; use crate::postgres::types::PgTypeInfo; -use crate::postgres::{PgValue, Postgres}; +use crate::postgres::{PgData, PgValue, Postgres}; use crate::types::Type; impl Type for f32 { @@ -35,14 +34,14 @@ impl Encode for f32 { } impl<'de> Decode<'de, Postgres> for f32 { - fn decode(value: Option>) -> crate::Result { - match value.try_into()? { - PgValue::Binary(mut buf) => buf + fn decode(value: PgValue<'de>) -> crate::Result { + match value.try_get()? { + PgData::Binary(mut buf) => buf .read_i32::() .map_err(Error::decode) .map(|value| f32::from_bits(value as u32)), - PgValue::Text(s) => f32::from_str(s).map_err(Error::decode), + PgData::Text(s) => f32::from_str(s).map_err(Error::decode), } } } @@ -71,14 +70,14 @@ impl Encode for f64 { } impl<'de> Decode<'de, Postgres> for f64 { - fn decode(value: Option>) -> crate::Result { - match value.try_into()? { - PgValue::Binary(mut buf) => buf + fn decode(value: PgValue<'de>) -> crate::Result { + match value.try_get()? { + PgData::Binary(mut buf) => buf .read_i64::() .map_err(Error::decode) .map(|value| f64::from_bits(value as u64)), - PgValue::Text(s) => f64::from_str(s).map_err(Error::decode), + PgData::Text(s) => f64::from_str(s).map_err(Error::decode), } } } diff --git a/sqlx-core/src/postgres/types/int.rs b/sqlx-core/src/postgres/types/int.rs index 3fb1fbe5..1e57ee7a 100644 --- a/sqlx-core/src/postgres/types/int.rs +++ b/sqlx-core/src/postgres/types/int.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::str::FromStr; use byteorder::{NetworkEndian, ReadBytesExt}; @@ -7,7 +6,7 @@ use crate::decode::Decode; use crate::encode::Encode; use crate::postgres::protocol::TypeId; use crate::postgres::types::PgTypeInfo; -use crate::postgres::{PgValue, Postgres}; +use crate::postgres::{PgData, PgValue, Postgres}; use crate::types::Type; use crate::Error; @@ -35,10 +34,10 @@ impl Encode for i16 { } impl<'de> Decode<'de, Postgres> for i16 { - fn decode(value: Option>) -> crate::Result { - match value.try_into()? { - PgValue::Binary(mut buf) => buf.read_i16::().map_err(Error::decode), - PgValue::Text(s) => i16::from_str(s).map_err(Error::decode), + fn decode(value: PgValue<'de>) -> crate::Result { + match value.try_get()? { + PgData::Binary(mut buf) => buf.read_i16::().map_err(Error::decode), + PgData::Text(s) => i16::from_str(s).map_err(Error::decode), } } } @@ -67,10 +66,10 @@ impl Encode for i32 { } impl<'de> Decode<'de, Postgres> for i32 { - fn decode(value: Option>) -> crate::Result { - match value.try_into()? { - PgValue::Binary(mut buf) => buf.read_i32::().map_err(Error::decode), - PgValue::Text(s) => i32::from_str(s).map_err(Error::decode), + fn decode(value: PgValue<'de>) -> crate::Result { + match value.try_get()? { + PgData::Binary(mut buf) => buf.read_i32::().map_err(Error::decode), + PgData::Text(s) => i32::from_str(s).map_err(Error::decode), } } } @@ -99,10 +98,10 @@ impl Encode for i64 { } impl<'de> Decode<'de, Postgres> for i64 { - fn decode(value: Option>) -> crate::Result { - match value.try_into()? { - PgValue::Binary(mut buf) => buf.read_i64::().map_err(Error::decode), - PgValue::Text(s) => i64::from_str(s).map_err(Error::decode), + fn decode(value: PgValue<'de>) -> crate::Result { + match value.try_get()? { + PgData::Binary(mut buf) => buf.read_i64::().map_err(Error::decode), + PgData::Text(s) => i64::from_str(s).map_err(Error::decode), } } } diff --git a/sqlx-core/src/postgres/types/ipnetwork.rs b/sqlx-core/src/postgres/types/ipnetwork.rs index c01b8b51..aba69c8e 100644 --- a/sqlx-core/src/postgres/types/ipnetwork.rs +++ b/sqlx-core/src/postgres/types/ipnetwork.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::net::{Ipv4Addr, Ipv6Addr}; use ipnetwork::{IpNetwork, Ipv4Network, Ipv6Network}; @@ -6,9 +5,9 @@ use ipnetwork::{IpNetwork, Ipv4Network, Ipv6Network}; use crate::decode::Decode; use crate::encode::Encode; use crate::postgres::protocol::TypeId; -use crate::postgres::row::PgValue; use crate::postgres::types::PgTypeInfo; -use crate::postgres::Postgres; +use crate::postgres::value::PgValue; +use crate::postgres::{PgData, Postgres}; use crate::types::Type; use crate::Error; @@ -67,10 +66,10 @@ impl Encode for IpNetwork { } impl<'de> Decode<'de, Postgres> for IpNetwork { - fn decode(value: Option>) -> crate::Result { - match value.try_into()? { - PgValue::Binary(buf) => decode(buf), - PgValue::Text(s) => s.parse().map_err(|err| crate::Error::decode(err)), + fn decode(value: PgValue<'de>) -> crate::Result { + match value.try_get()? { + PgData::Binary(buf) => decode(buf), + PgData::Text(s) => s.parse().map_err(crate::Error::decode), } } } diff --git a/sqlx-core/src/postgres/types/json.rs b/sqlx-core/src/postgres/types/json.rs index 38737ece..caf9bafd 100644 --- a/sqlx-core/src/postgres/types/json.rs +++ b/sqlx-core/src/postgres/types/json.rs @@ -26,7 +26,7 @@ impl Encode for JsonValue { } impl<'de> Decode<'de, Postgres> for JsonValue { - fn decode(value: Option>) -> crate::Result { + fn decode(value: PgValue<'de>) -> crate::Result { as Decode>::decode(value).map(|item| item.0) } } @@ -44,7 +44,7 @@ impl Encode for &'_ JsonRawValue { } impl<'de> Decode<'de, Postgres> for &'de JsonRawValue { - fn decode(value: Option>) -> crate::Result { + fn decode(value: PgValue<'de>) -> crate::Result { as Decode>::decode(value).map(|item| item.0) } } @@ -69,7 +69,7 @@ where T: 'de, T: Deserialize<'de>, { - fn decode(value: Option>) -> crate::Result { + fn decode(value: PgValue<'de>) -> crate::Result { as Decode>::decode(value).map(|item| Self(item.0)) } } diff --git a/sqlx-core/src/postgres/types/mod.rs b/sqlx-core/src/postgres/types/mod.rs index a04a07d8..2dd0fc31 100644 --- a/sqlx-core/src/postgres/types/mod.rs +++ b/sqlx-core/src/postgres/types/mod.rs @@ -281,10 +281,12 @@ impl<'de, T> Decode<'de, Postgres> for Option where T: Decode<'de, Postgres>, { - fn decode(value: Option>) -> crate::Result { - value - .map(|value| >::decode(Some(value))) - .transpose() + fn decode(value: PgValue<'de>) -> crate::Result { + Ok(if value.get().is_some() { + Some(>::decode(value)?) + } else { + None + }) } } diff --git a/sqlx-core/src/postgres/types/raw/array.rs b/sqlx-core/src/postgres/types/raw/array.rs index a6545438..a5b329b4 100644 --- a/sqlx-core/src/postgres/types/raw/array.rs +++ b/sqlx-core/src/postgres/types/raw/array.rs @@ -2,10 +2,9 @@ use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::io::{Buf, BufMut}; use crate::postgres::types::raw::sequence::PgSequenceDecoder; -use crate::postgres::{PgValue, Postgres}; +use crate::postgres::{PgData, PgValue, Postgres}; use crate::types::Type; use byteorder::BE; -use std::convert::TryInto; use std::marker::PhantomData; // https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/include/utils/array.h;h=7f7e744cb12bc872f628f90dad99dfdf074eb314;hb=master#l6 @@ -94,17 +93,17 @@ where T: for<'arr> Decode<'arr, Postgres>, T: Type, { - pub(crate) fn new(value: Option>) -> crate::Result { - let mut value = value.try_into()?; + pub(crate) fn new(value: PgValue<'de>) -> crate::Result { + let mut data = value.try_get()?; - match value { - PgValue::Binary(ref mut buf) => { + match data { + PgData::Binary(ref mut buf) => { // number of dimensions of the array let ndim = buf.get_i32::()?; if ndim == 0 { return Ok(Self { - inner: PgSequenceDecoder::new(PgValue::Binary(&[]), false), + inner: PgSequenceDecoder::new(PgData::Binary(&[]), false), phantom: PhantomData, }); } @@ -141,11 +140,11 @@ where } } - PgValue::Text(_) => {} + PgData::Text(_) => {} } Ok(Self { - inner: PgSequenceDecoder::new(value, false), + inner: PgSequenceDecoder::new(data, false), phantom: PhantomData, }) } @@ -172,7 +171,7 @@ where mod tests { use super::PgArrayDecoder; use super::PgArrayEncoder; - use crate::postgres::{PgValue, Postgres}; + use crate::postgres::{PgData, PgValue, Postgres}; const BUF_BINARY_I32: &[u8] = b"\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x17\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00\x04\x00\x00\x00\x04"; @@ -193,7 +192,7 @@ mod tests { #[test] fn it_decodes_text_i32() -> crate::Result { let s = "{1,152,-12412}"; - let mut decoder = PgArrayDecoder::::new(Some(PgValue::Text(s)))?; + let mut decoder = PgArrayDecoder::::new(Some(PgData::Text(s)))?; assert_eq!(decoder.decode()?, Some(1)); assert_eq!(decoder.decode()?, Some(152)); @@ -206,7 +205,7 @@ mod tests { #[test] fn it_decodes_text_str() -> crate::Result { let s = "{\"\",\"\\\"\"}"; - let mut decoder = PgArrayDecoder::::new(Some(PgValue::Text(s)))?; + let mut decoder = PgArrayDecoder::::new(Some(PgData::Text(s)))?; assert_eq!(decoder.decode()?, Some("".to_string())); assert_eq!(decoder.decode()?, Some("\"".to_string())); @@ -217,8 +216,8 @@ mod tests { #[test] fn it_decodes_binary_nulls() -> crate::Result { - let mut decoder = PgArrayDecoder::>::new(Some(PgValue::Binary( - b"\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\xff\xff\xff\xff\x00\x00\x00\x01\x01\xff\xff\xff\xff\x00\x00\x00\x01\x00" + let mut decoder = PgArrayDecoder::>::new(Some(PgData::Binary( + b"\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\xff\xff\xff\xff\x00\x00\x00\x01\x01\xff\xff\xff\xff\x00\x00\x00\x01\x00", 0, )))?; assert_eq!(decoder.decode()?, Some(None)); @@ -231,7 +230,7 @@ mod tests { #[test] fn it_decodes_binary_i32() -> crate::Result { - let mut decoder = PgArrayDecoder::::new(Some(PgValue::Binary(BUF_BINARY_I32)))?; + let mut decoder = PgArrayDecoder::::new(Some(PgData::Binary(BUF_BINARY_I32, 0)))?; let val_1 = decoder.decode()?; let val_2 = decoder.decode()?; diff --git a/sqlx-core/src/postgres/types/raw/json.rs b/sqlx-core/src/postgres/types/raw/json.rs index 3f5e4c52..705a7579 100644 --- a/sqlx-core/src/postgres/types/raw/json.rs +++ b/sqlx-core/src/postgres/types/raw/json.rs @@ -3,10 +3,9 @@ use crate::encode::Encode; use crate::io::{Buf, BufMut}; use crate::postgres::protocol::TypeId; use crate::postgres::types::PgTypeInfo; -use crate::postgres::{PgValue, Postgres}; +use crate::postgres::{PgData, PgValue, Postgres}; use crate::types::Type; use serde::{Deserialize, Serialize}; -use std::convert::TryInto; #[derive(Debug, PartialEq)] pub struct PgJson(pub T); @@ -32,10 +31,10 @@ where T: 'de, T: Deserialize<'de>, { - fn decode(value: Option>) -> crate::Result { - (match value.try_into()? { - PgValue::Text(s) => serde_json::from_str(s), - PgValue::Binary(buf) => serde_json::from_slice(buf), + fn decode(value: PgValue<'de>) -> crate::Result { + (match value.try_get()? { + PgData::Text(s) => serde_json::from_str(s), + PgData::Binary(buf) => serde_json::from_slice(buf), }) .map(PgJson) .map_err(crate::Error::decode) @@ -71,10 +70,10 @@ where T: 'de, T: Deserialize<'de>, { - fn decode(value: Option>) -> crate::Result { - (match value.try_into()? { - PgValue::Text(s) => serde_json::from_str(s), - PgValue::Binary(mut buf) => { + fn decode(value: PgValue<'de>) -> crate::Result { + (match value.try_get()? { + PgData::Text(s) => serde_json::from_str(s), + PgData::Binary(mut buf) => { let version = buf.get_u8()?; assert_eq!( diff --git a/sqlx-core/src/postgres/types/raw/numeric.rs b/sqlx-core/src/postgres/types/raw/numeric.rs index 7e8ff825..6284c96d 100644 --- a/sqlx-core/src/postgres/types/raw/numeric.rs +++ b/sqlx-core/src/postgres/types/raw/numeric.rs @@ -1,4 +1,4 @@ -use std::convert::TryInto; +use core::convert::TryInto; use byteorder::BigEndian; @@ -6,7 +6,7 @@ use crate::decode::Decode; use crate::encode::Encode; use crate::io::{Buf, BufMut}; use crate::postgres::protocol::TypeId; -use crate::postgres::{PgTypeInfo, PgValue, Postgres}; +use crate::postgres::{PgData, PgTypeInfo, PgValue, Postgres}; use crate::types::Type; use crate::Error; @@ -109,8 +109,8 @@ impl PgNumeric { /// Receiving `PgNumeric` is currently only supported for the Postgres /// binary (prepared statements) protocol. impl Decode<'_, Postgres> for PgNumeric { - fn decode(value: Option) -> crate::Result { - if let PgValue::Binary(bytes) = value.try_into()? { + fn decode(value: PgValue) -> crate::Result { + if let PgData::Binary(bytes) = value.try_get()? { Self::from_bytes(bytes) } else { Err(Error::Decode( @@ -119,6 +119,7 @@ impl Decode<'_, Postgres> for PgNumeric { } } } + /// ### Panics /// /// * If `digits.len()` overflows `i16` diff --git a/sqlx-core/src/postgres/types/raw/record.rs b/sqlx-core/src/postgres/types/raw/record.rs index de21a631..b10e257a 100644 --- a/sqlx-core/src/postgres/types/raw/record.rs +++ b/sqlx-core/src/postgres/types/raw/record.rs @@ -2,10 +2,9 @@ use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::io::Buf; use crate::postgres::types::raw::sequence::PgSequenceDecoder; -use crate::postgres::{PgValue, Postgres}; +use crate::postgres::{PgData, PgValue, Postgres}; use crate::types::Type; use byteorder::BigEndian; -use std::convert::TryInto; pub struct PgRecordEncoder<'a> { buf: &'a mut Vec, @@ -62,17 +61,17 @@ impl<'a> PgRecordEncoder<'a> { pub struct PgRecordDecoder<'de>(PgSequenceDecoder<'de>); impl<'de> PgRecordDecoder<'de> { - pub fn new(value: Option>) -> crate::Result { - let mut value: PgValue = value.try_into()?; + pub fn new(value: PgValue<'de>) -> crate::Result { + let mut data = value.try_get()?; - match value { - PgValue::Text(_) => {} - PgValue::Binary(ref mut buf) => { + match data { + PgData::Text(_) => {} + PgData::Binary(ref mut buf) => { let _expected_len = buf.get_u32::()?; } } - Ok(Self(PgSequenceDecoder::new(value, true))) + Ok(Self(PgSequenceDecoder::new(data, true))) } #[inline] @@ -119,7 +118,7 @@ fn test_decode_field() { encoder.encode(&value); let buf = buf.as_slice(); - let mut decoder = PgRecordDecoder::new(Some(PgValue::Binary(buf))).unwrap(); + let mut decoder = PgRecordDecoder::new(Some(PgData::Binary(buf, 0))).unwrap(); let value_decoded: String = decoder.decode().unwrap(); assert_eq!(value_decoded, value); diff --git a/sqlx-core/src/postgres/types/raw/sequence.rs b/sqlx-core/src/postgres/types/raw/sequence.rs index 6141fc22..652924ba 100644 --- a/sqlx-core/src/postgres/types/raw/sequence.rs +++ b/sqlx-core/src/postgres/types/raw/sequence.rs @@ -1,31 +1,31 @@ use crate::decode::Decode; use crate::io::Buf; -use crate::postgres::{PgValue, Postgres}; +use crate::postgres::{PgData, PgValue, Postgres}; use crate::types::Type; use byteorder::BigEndian; pub(crate) struct PgSequenceDecoder<'de> { - value: PgValue<'de>, + data: PgData<'de>, len: usize, mixed: bool, } impl<'de> PgSequenceDecoder<'de> { - pub(crate) fn new(mut value: PgValue<'de>, mixed: bool) -> Self { - match value { - PgValue::Binary(_) => { + pub(crate) fn new(mut data: PgData<'de>, mixed: bool) -> Self { + match data { + PgData::Binary(_) => { // assume that this has already gotten tweaked by the caller as // tuples and arrays have a very different header } - PgValue::Text(ref mut s) => { + PgData::Text(ref mut s) => { // remove the outer ( ... ) or { ... } *s = &s[1..(s.len() - 1)]; } } Self { - value, + data, mixed, len: 0, } @@ -40,8 +40,8 @@ impl<'de> PgSequenceDecoder<'de> { T: for<'seq> Decode<'seq, Postgres>, T: Type, { - match self.value { - PgValue::Binary(ref mut buf) => { + match self.data { + PgData::Binary(ref mut buf) => { if buf.is_empty() { return Ok(None); } @@ -59,13 +59,15 @@ impl<'de> PgSequenceDecoder<'de> { let len = buf.get_i32::()? as isize; let value = if len < 0 { - T::decode(None)? + // TODO: Grab the correct element OID + T::decode(PgValue::null(0))? } else { let value_buf = &buf[..(len as usize)]; *buf = &buf[(len as usize)..]; - T::decode(Some(PgValue::Binary(value_buf)))? + // TODO: Grab the correct element OID + T::decode(PgValue::bytes(0, value_buf))? }; self.len += 1; @@ -73,7 +75,7 @@ impl<'de> PgSequenceDecoder<'de> { Ok(Some(value)) } - PgValue::Text(ref mut s) => { + PgData::Text(ref mut s) => { if s.is_empty() { return Ok(None); } @@ -134,12 +136,15 @@ impl<'de> PgSequenceDecoder<'de> { }; let value = T::decode(if end == Some(0) { - None + // TODO: Grab the correct element OID + PgValue::null(0) } else if !self.mixed && value == "NULL" { // Yes, in arrays the text encoding of a NULL is just NULL - None + // TODO: Grab the correct element OID + PgValue::null(0) } else { - Some(PgValue::Text(&value)) + // TODO: Grab the correct element OID + PgValue::str(0, &*value) })?; *s = if let Some(end) = end { @@ -158,7 +163,7 @@ impl<'de> PgSequenceDecoder<'de> { impl<'de> From<&'de str> for PgSequenceDecoder<'de> { fn from(s: &'de str) -> Self { - Self::new(PgValue::Text(s), false) + Self::new(PgData::Text(s), false) } } diff --git a/sqlx-core/src/postgres/types/record.rs b/sqlx-core/src/postgres/types/record.rs index bde80ca7..3cb7b195 100644 --- a/sqlx-core/src/postgres/types/record.rs +++ b/sqlx-core/src/postgres/types/record.rs @@ -1,8 +1,8 @@ use crate::decode::Decode; use crate::postgres::protocol::TypeId; -use crate::postgres::row::PgValue; use crate::postgres::types::raw::PgRecordDecoder; use crate::postgres::types::PgTypeInfo; +use crate::postgres::value::PgValue; use crate::postgres::Postgres; use crate::types::Type; @@ -42,7 +42,7 @@ macro_rules! impl_pg_record_for_tuple { $($T: Type,)+ $($T: for<'tup> Decode<'tup, Postgres>,)+ { - fn decode(value: Option>) -> crate::Result { + fn decode(value: PgValue<'de>) -> crate::Result { let mut decoder = PgRecordDecoder::new(value)?; $(let $idx: $T = decoder.decode()?;)+ diff --git a/sqlx-core/src/postgres/types/str.rs b/sqlx-core/src/postgres/types/str.rs index 098a67c2..2d8808c7 100644 --- a/sqlx-core/src/postgres/types/str.rs +++ b/sqlx-core/src/postgres/types/str.rs @@ -1,11 +1,10 @@ -use std::convert::TryInto; use std::str::from_utf8; use crate::decode::Decode; use crate::encode::Encode; use crate::postgres::protocol::TypeId; -use crate::postgres::row::PgValue; use crate::postgres::types::PgTypeInfo; +use crate::postgres::value::{PgData, PgValue}; use crate::postgres::Postgres; use crate::types::Type; use crate::Error; @@ -67,16 +66,16 @@ impl Encode for String { } impl<'de> Decode<'de, Postgres> for String { - fn decode(buf: Option>) -> crate::Result { - <&'de str as Decode>::decode(buf).map(ToOwned::to_owned) + fn decode(value: PgValue<'de>) -> crate::Result { + <&'de str as Decode>::decode(value).map(ToOwned::to_owned) } } impl<'de> Decode<'de, Postgres> for &'de str { - fn decode(value: Option>) -> crate::Result { - match value.try_into()? { - PgValue::Binary(buf) => from_utf8(buf).map_err(Error::decode), - PgValue::Text(s) => Ok(s), + fn decode(value: PgValue<'de>) -> crate::Result { + match value.try_get()? { + PgData::Binary(buf) => from_utf8(buf).map_err(Error::decode), + PgData::Text(s) => Ok(s), } } } diff --git a/sqlx-core/src/postgres/types/time.rs b/sqlx-core/src/postgres/types/time.rs index dcc40025..ecf5fbf8 100644 --- a/sqlx-core/src/postgres/types/time.rs +++ b/sqlx-core/src/postgres/types/time.rs @@ -10,7 +10,7 @@ use crate::encode::Encode; use crate::io::Buf; use crate::postgres::protocol::TypeId; use crate::postgres::types::PgTypeInfo; -use crate::postgres::{PgValue, Postgres}; +use crate::postgres::{PgData, PgValue, Postgres}; use crate::types::Type; const POSTGRES_EPOCH: PrimitiveDateTime = date!(2000 - 1 - 1).midnight(); @@ -109,15 +109,15 @@ fn from_microseconds_since_midnight(mut microsecond: u64) -> crate::Result Decode<'de, Postgres> for Time { - fn decode(value: Option>) -> crate::Result { - match value.try_into()? { - PgValue::Binary(mut buf) => { + fn decode(value: PgValue<'de>) -> crate::Result { + match value.try_get()? { + PgData::Binary(mut buf) => { let micros: i64 = buf.get_i64::()?; from_microseconds_since_midnight(micros as u64) } - PgValue::Text(s) => { + PgData::Text(s) => { // If there are less than 9 digits after the decimal point // We need to zero-pad // TODO: Ask [time] to add a parse % for less-than-fixed-9 nanos @@ -147,15 +147,15 @@ impl Encode for Time { } impl<'de> Decode<'de, Postgres> for Date { - fn decode(value: Option>) -> crate::Result { - match value.try_into()? { - PgValue::Binary(mut buf) => { + fn decode(value: PgValue<'de>) -> crate::Result { + match value.try_get()? { + PgData::Binary(mut buf) => { let n: i32 = buf.get_i32::()?; Ok(date!(2000 - 1 - 1) + n.days()) } - PgValue::Text(s) => Date::parse(s, "%Y-%m-%d").map_err(crate::Error::decode), + PgData::Text(s) => Date::parse(s, "%Y-%m-%d").map_err(crate::Error::decode), } } } @@ -177,16 +177,16 @@ impl Encode for Date { } impl<'de> Decode<'de, Postgres> for PrimitiveDateTime { - fn decode(value: Option>) -> crate::Result { - match value.try_into()? { - PgValue::Binary(mut buf) => { + fn decode(value: PgValue<'de>) -> crate::Result { + match value.try_get()? { + PgData::Binary(mut buf) => { let n: i64 = buf.get_i64::()?; Ok(POSTGRES_EPOCH + n.microseconds()) } // TODO: Try and fix duplication between here and MySQL - PgValue::Text(s) => { + PgData::Text(s) => { // If there are less than 9 digits after the decimal point // We need to zero-pad // TODO: Ask [time] to add a parse % for less-than-fixed-9 nanos @@ -233,7 +233,7 @@ impl Encode for PrimitiveDateTime { } impl<'de> Decode<'de, Postgres> for OffsetDateTime { - fn decode(value: Option>) -> crate::Result { + fn decode(value: PgValue<'de>) -> crate::Result { let primitive: PrimitiveDateTime = Decode::::decode(value)?; Ok(primitive.assume_utc()) @@ -285,18 +285,18 @@ fn test_encode_time() { #[test] fn test_decode_time() { let buf = [0u8; 8]; - let time: Time = Decode::::decode(Some(PgValue::Binary(&buf))).unwrap(); + let time: Time = Decode::::decode(Some(PgData::Binary(&buf))).unwrap(); assert_eq!(time, time!(0:00)); // half an hour let buf = (1_000_000i64 * 60 * 30).to_be_bytes(); - let time: Time = Decode::::decode(Some(PgValue::Binary(&buf))).unwrap(); + let time: Time = Decode::::decode(Some(PgData::Binary(&buf))).unwrap(); assert_eq!(time, time!(0:30)); // 12:53:05.125305 let buf = (1_000_000i64 * 60 * 60 * 12 + 1_000_000i64 * 60 * 53 + 1_000_000i64 * 5 + 125305) .to_be_bytes(); - let time: Time = Decode::::decode(Some(PgValue::Binary(&buf))).unwrap(); + let time: Time = Decode::::decode(Some(PgData::Binary(&buf))).unwrap(); assert_eq!(time, time!(12:53:05.125305)); } @@ -325,21 +325,21 @@ fn test_encode_datetime() { #[test] fn test_decode_datetime() { let buf = [0u8; 8]; - let date: PrimitiveDateTime = Decode::::decode(Some(PgValue::Binary(&buf))).unwrap(); + let date: PrimitiveDateTime = Decode::::decode(Some(PgData::Binary(&buf))).unwrap(); assert_eq!( date, PrimitiveDateTime::new(date!(2000 - 01 - 01), time!(00:00:00)) ); let buf = 3_600_000_000i64.to_be_bytes(); - let date: PrimitiveDateTime = Decode::::decode(Some(PgValue::Binary(&buf))).unwrap(); + let date: PrimitiveDateTime = Decode::::decode(Some(PgData::Binary(&buf))).unwrap(); assert_eq!( date, PrimitiveDateTime::new(date!(2000 - 01 - 01), time!(01:00:00)) ); let buf = 629_377_265_000_000i64.to_be_bytes(); - let date: PrimitiveDateTime = Decode::::decode(Some(PgValue::Binary(&buf))).unwrap(); + let date: PrimitiveDateTime = Decode::::decode(Some(PgData::Binary(&buf))).unwrap(); assert_eq!( date, PrimitiveDateTime::new(date!(2019 - 12 - 11), time!(11:01:05)) @@ -372,21 +372,21 @@ fn test_encode_offsetdatetime() { #[test] fn test_decode_offsetdatetime() { let buf = [0u8; 8]; - let date: OffsetDateTime = Decode::::decode(Some(PgValue::Binary(&buf))).unwrap(); + let date: OffsetDateTime = Decode::::decode(Some(PgData::Binary(&buf))).unwrap(); assert_eq!( date, PrimitiveDateTime::new(date!(2000 - 01 - 01), time!(00:00:00)).assume_utc() ); let buf = 3_600_000_000i64.to_be_bytes(); - let date: OffsetDateTime = Decode::::decode(Some(PgValue::Binary(&buf))).unwrap(); + let date: OffsetDateTime = Decode::::decode(Some(PgData::Binary(&buf))).unwrap(); assert_eq!( date, PrimitiveDateTime::new(date!(2000 - 01 - 01), time!(01:00:00)).assume_utc() ); let buf = 629_377_265_000_000i64.to_be_bytes(); - let date: OffsetDateTime = Decode::::decode(Some(PgValue::Binary(&buf))).unwrap(); + let date: OffsetDateTime = Decode::::decode(Some(PgData::Binary(&buf))).unwrap(); assert_eq!( date, PrimitiveDateTime::new(date!(2019 - 12 - 11), time!(11:01:05)).assume_utc() @@ -417,14 +417,14 @@ fn test_encode_date() { #[test] fn test_decode_date() { let buf = [0; 4]; - let date: Date = Decode::::decode(Some(PgValue::Binary(&buf))).unwrap(); + let date: Date = Decode::::decode(Some(PgData::Binary(&buf))).unwrap(); assert_eq!(date, date!(2000 - 01 - 01)); let buf = 366i32.to_be_bytes(); - let date: Date = Decode::::decode(Some(PgValue::Binary(&buf))).unwrap(); + let date: Date = Decode::::decode(Some(PgData::Binary(&buf))).unwrap(); assert_eq!(date, date!(2001 - 01 - 01)); let buf = 7284i32.to_be_bytes(); - let date: Date = Decode::::decode(Some(PgValue::Binary(&buf))).unwrap(); + let date: Date = Decode::::decode(Some(PgData::Binary(&buf))).unwrap(); assert_eq!(date, date!(2019 - 12 - 11)); } diff --git a/sqlx-core/src/postgres/types/uuid.rs b/sqlx-core/src/postgres/types/uuid.rs index 0b5720c5..2a0d30c4 100644 --- a/sqlx-core/src/postgres/types/uuid.rs +++ b/sqlx-core/src/postgres/types/uuid.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::str::FromStr; use uuid::Uuid; @@ -6,8 +5,8 @@ use uuid::Uuid; use crate::decode::Decode; use crate::encode::Encode; use crate::postgres::protocol::TypeId; -use crate::postgres::row::PgValue; use crate::postgres::types::PgTypeInfo; +use crate::postgres::value::{PgData, PgValue}; use crate::postgres::Postgres; use crate::types::Type; @@ -36,10 +35,10 @@ impl Encode for Uuid { } impl<'de> Decode<'de, Postgres> for Uuid { - fn decode(value: Option>) -> crate::Result { - match value.try_into()? { - PgValue::Binary(buf) => Uuid::from_slice(buf).map_err(|err| crate::Error::decode(err)), - PgValue::Text(s) => Uuid::from_str(s).map_err(|err| crate::Error::decode(err)), + fn decode(value: PgValue<'de>) -> crate::Result { + match value.try_get()? { + PgData::Binary(buf) => Uuid::from_slice(buf).map_err(crate::Error::decode), + PgData::Text(s) => Uuid::from_str(s).map_err(crate::Error::decode), } } } diff --git a/sqlx-core/src/postgres/value.rs b/sqlx-core/src/postgres/value.rs new file mode 100644 index 00000000..0878aa28 --- /dev/null +++ b/sqlx-core/src/postgres/value.rs @@ -0,0 +1,70 @@ +use crate::error::UnexpectedNullError; +use crate::postgres::{PgTypeInfo, Postgres}; +use crate::value::RawValue; +use std::str::from_utf8; + +#[derive(Debug, Copy, Clone)] +pub enum PgData<'c> { + Binary(&'c [u8]), + Text(&'c str), +} + +#[derive(Debug)] +pub struct PgValue<'c> { + type_oid: u32, + data: Option>, +} + +impl<'c> PgValue<'c> { + /// Gets the binary or text data for this value; or, `UnexpectedNullError` if this + /// is a `NULL` value. + pub(crate) fn try_get(&self) -> crate::Result> { + match self.data { + Some(data) => Ok(data), + None => Err(crate::Error::decode(UnexpectedNullError)), + } + } + + /// Gets the binary or text data for this value; or, `None` if this + /// is a `NULL` value. + #[inline] + pub fn get(&self) -> Option> { + self.data + } + + pub(crate) fn null(type_oid: u32) -> Self { + Self { + type_oid, + data: None, + } + } + + pub(crate) fn bytes(type_oid: u32, buf: &'c [u8]) -> Self { + Self { + type_oid, + data: Some(PgData::Binary(buf)), + } + } + + pub(crate) fn utf8(type_oid: u32, buf: &'c [u8]) -> crate::Result { + Ok(Self { + type_oid, + data: Some(PgData::Text(from_utf8(&buf).map_err(crate::Error::decode)?)), + }) + } + + pub(crate) fn str(type_oid: u32, s: &'c str) -> Self { + Self { + type_oid, + data: Some(PgData::Text(s)), + } + } +} + +impl<'c> RawValue<'c> for PgValue<'c> { + type Database = Postgres; + + fn type_info(&self) -> PgTypeInfo { + PgTypeInfo::with_oid(self.type_oid) + } +} diff --git a/tests/postgres-types.rs b/tests/postgres-types.rs index c8369e49..565b6c27 100644 --- a/tests/postgres-types.rs +++ b/tests/postgres-types.rs @@ -474,7 +474,7 @@ END $$; } impl<'de> Decode<'de, Postgres> for RecordEmpty { - fn decode(_value: Option>) -> sqlx::Result { + fn decode(_value: PgValue<'de>) -> sqlx::Result { Ok(RecordEmpty {}) } } @@ -506,7 +506,7 @@ END $$; } impl<'de> Decode<'de, Postgres> for Record1 { - fn decode(value: Option>) -> sqlx::Result { + fn decode(value: PgValue<'de>) -> sqlx::Result { let mut decoder = PgRecordDecoder::new(value)?; let _1 = decoder.decode()?;