diff --git a/sqlx-core/src/postgres/arguments.rs b/sqlx-core/src/postgres/arguments.rs index fe1247b6..0af724f1 100644 --- a/sqlx-core/src/postgres/arguments.rs +++ b/sqlx-core/src/postgres/arguments.rs @@ -3,16 +3,16 @@ use byteorder::{ByteOrder, NetworkEndian}; use crate::arguments::Arguments; use crate::encode::{Encode, IsNull}; use crate::io::BufMut; -use crate::postgres::Postgres; +use crate::postgres::{PgRawBuffer, PgTypeInfo, Postgres}; use crate::types::Type; #[derive(Default)] pub struct PgArguments { - // OIDs of the bind parameters - pub(super) types: Vec, + // Types of the bind parameters + pub(super) types: Vec, // Write buffer for serializing bind values - pub(super) values: Vec, + pub(super) buffer: PgRawBuffer, } impl Arguments for PgArguments { @@ -20,25 +20,24 @@ impl Arguments for PgArguments { fn reserve(&mut self, len: usize, size: usize) { self.types.reserve(len); - self.values.reserve(size); + self.buffer.reserve(size); } fn add(&mut self, value: T) where - T: Type, - T: Encode, + T: Type + Encode, { // TODO: When/if we receive types that do _not_ support BINARY, we need to check here // TODO: There is no need to be explicit unless we are expecting mixed BINARY / TEXT - self.types.push(>::type_info().id.0); + self.types.push(>::type_info()); - let pos = self.values.len(); + // Reserves space for the length of the value + let pos = self.buffer.len(); + self.buffer.put_i32::(0); - self.values.put_i32::(0); - - let len = if let IsNull::No = value.encode_nullable(&mut self.values) { - (self.values.len() - pos - 4) as i32 + let len = if let IsNull::No = value.encode_nullable(&mut self.buffer) { + (self.buffer.len() - pos - 4) as i32 } else { // Write a -1 for the len to indicate NULL // TODO: It is illegal for [encode] to write any data @@ -47,6 +46,6 @@ impl Arguments for PgArguments { }; // Write-back the len to the beginning of this frame (not including the len of len) - NetworkEndian::write_i32(&mut self.values[pos..], len as i32); + NetworkEndian::write_i32(&mut self.buffer[pos..], len as i32); } } diff --git a/sqlx-core/src/postgres/buffer.rs b/sqlx-core/src/postgres/buffer.rs new file mode 100644 index 00000000..88b563fd --- /dev/null +++ b/sqlx-core/src/postgres/buffer.rs @@ -0,0 +1,57 @@ +use crate::postgres::type_info::SharedStr; +use crate::postgres::PgConnection; +use byteorder::{ByteOrder, NetworkEndian}; +use core::ops::{Deref, DerefMut}; + +#[derive(Debug, Default, PartialEq)] +pub struct PgRawBuffer { + inner: Vec, + + // Whenever an `Encode` impl encounters a `PgTypeInfo` object that does not have an OID + // It pushes a "hole" that must be patched later + // The hole is a `usize` offset into the buffer with the type name that should be resolved + // This is done for Records and Arrays as the OID is needed well before we are in an async + // function and can just ask postgres + type_holes: Vec<(usize, SharedStr)>, +} + +impl PgRawBuffer { + // Extends the inner buffer by enough space to have an OID + // Remembers where the OID goes and type name for the OID + pub(crate) fn push_type_hole(&mut self, type_name: &SharedStr) { + let offset = self.len(); + + self.extend_from_slice(&0_u32.to_be_bytes()); + self.type_holes.push((offset, type_name.clone())); + } + + // Patch all remembered type holes + // This should only go out and ask postgres if we have not seen the type name yet + pub(crate) async fn patch_type_holes( + &mut self, + connection: &mut PgConnection, + ) -> crate::Result<()> { + for (offset, name) in &self.type_holes { + let oid = connection.get_type_id_by_name(&*name).await?; + NetworkEndian::write_u32(&mut self.inner[*offset..], oid); + } + + Ok(()) + } +} + +impl Deref for PgRawBuffer { + type Target = Vec; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for PgRawBuffer { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} diff --git a/sqlx-core/src/postgres/connection.rs b/sqlx-core/src/postgres/connection.rs index 8d9aabed..09e414b3 100644 --- a/sqlx-core/src/postgres/connection.rs +++ b/sqlx-core/src/postgres/connection.rs @@ -15,6 +15,7 @@ use crate::postgres::protocol::{ }; use crate::postgres::row::Statement; use crate::postgres::stream::PgStream; +use crate::postgres::type_info::SharedStr; use crate::postgres::{sasl, tls}; use crate::url::Url; @@ -89,6 +90,12 @@ pub struct PgConnection { // cache statement ID -> statement description pub(super) cache_statement: HashMap>, + // cache type name -> type OID + pub(super) cache_type_oid: HashMap, + + // cache type OID -> type name + pub(super) cache_type_name: HashMap, + // Work buffer for the value ranges of the current row // This is used as the backing memory for each Row's value indexes pub(super) current_row_values: Vec>>, @@ -251,6 +258,8 @@ impl PgConnection { current_row_values: Vec::with_capacity(10), next_statement_id: 1, is_ready: true, + cache_type_oid: HashMap::new(), + cache_type_name: HashMap::new(), cache_statement_id: HashMap::with_capacity(10), cache_statement: HashMap::with_capacity(10), process_id: key_data.process_id, diff --git a/sqlx-core/src/postgres/cursor.rs b/sqlx-core/src/postgres/cursor.rs index b737652b..13f653ef 100644 --- a/sqlx-core/src/postgres/cursor.rs +++ b/sqlx-core/src/postgres/cursor.rs @@ -53,7 +53,7 @@ impl<'c, 'q> Cursor<'c, 'q> for PgCursor<'c, 'q> { } } -fn parse_row_description(rd: RowDescription) -> Statement { +fn parse_row_description(conn: &mut PgConnection, rd: RowDescription) -> Statement { let mut names = HashMap::new(); let mut columns = Vec::new(); @@ -65,8 +65,10 @@ fn parse_row_description(rd: RowDescription) -> Statement { names.insert(name.clone(), index); } + let type_info = conn.get_type_info_by_oid(field.type_id.0); + columns.push(Column { - type_id: field.type_id, + type_info, format: field.type_format, }); } @@ -100,7 +102,11 @@ async fn expect_desc(conn: &mut PgConnection) -> crate::Result { } }; - Ok(description.map(parse_row_description).unwrap_or_default()) + if let Some(description) = description { + Ok(parse_row_description(conn, description)) + } else { + Ok(Statement::default()) + } } // A form of describe that uses the statement cache @@ -159,7 +165,7 @@ async fn next<'a, 'c: 'a, 'q: 'a>( Message::RowDescription => { let rd = RowDescription::read(conn.stream.buffer())?; - cursor.statement = Arc::new(parse_row_description(rd)); + cursor.statement = Arc::new(parse_row_description(conn, rd)); } Message::DataRow => { diff --git a/sqlx-core/src/postgres/database.rs b/sqlx-core/src/postgres/database.rs index 5deabcba..585cbe45 100644 --- a/sqlx-core/src/postgres/database.rs +++ b/sqlx-core/src/postgres/database.rs @@ -2,7 +2,9 @@ use crate::cursor::HasCursor; use crate::database::Database; -use crate::postgres::{PgArguments, PgConnection, PgCursor, PgError, PgRow, PgTypeInfo, PgValue}; +use crate::postgres::{ + PgArguments, PgConnection, PgCursor, PgError, PgRawBuffer, PgRow, PgTypeInfo, PgValue, +}; use crate::row::HasRow; use crate::value::HasRawValue; @@ -19,7 +21,7 @@ impl Database for Postgres { type TableId = u32; - type RawBuffer = Vec; + type RawBuffer = PgRawBuffer; type Error = PgError; } diff --git a/sqlx-core/src/postgres/executor.rs b/sqlx-core/src/postgres/executor.rs index 1d388482..31fabca5 100644 --- a/sqlx-core/src/postgres/executor.rs +++ b/sqlx-core/src/postgres/executor.rs @@ -12,8 +12,12 @@ use crate::postgres::protocol::{ self, CommandComplete, Field, Message, ParameterDescription, ReadyForQuery, RowDescription, StatementId, TypeFormat, TypeId, }; -use crate::postgres::types::SharedStr; -use crate::postgres::{PgArguments, PgConnection, PgCursor, PgRow, PgTypeInfo, Postgres}; +use crate::postgres::type_info::SharedStr; +use crate::postgres::types::try_resolve_type_name; +use crate::postgres::{ + PgArguments, PgConnection, PgCursor, PgQueryAs, PgRow, PgTypeInfo, Postgres, +}; +use crate::query_as::query_as; use crate::row::Row; impl PgConnection { @@ -21,23 +25,40 @@ impl PgConnection { self.stream.write(protocol::Query(query)); } - pub(crate) fn write_prepare(&mut self, query: &str, args: &PgArguments) -> StatementId { + pub(crate) async fn write_prepare( + &mut self, + query: &str, + args: &PgArguments, + ) -> crate::Result { if let Some(&id) = self.cache_statement_id.get(query) { - id + Ok(id) } else { let id = StatementId(self.next_statement_id); self.next_statement_id += 1; + // Build a list of type OIDs from the type info array provided by PgArguments + // This may need to query Postgres for an OID of a user-defined type + + let mut types = Vec::with_capacity(args.types.len()); + + for ty in &args.types { + types.push(if let Some(oid) = ty.id { + oid.0 + } else { + self.get_type_id_by_name(&*ty.name).await? + }); + } + self.stream.write(protocol::Parse { statement: id, + param_types: &*types, query, - param_types: &*args.types, }); self.cache_statement_id.insert(query.into(), id); - id + Ok(id) } } @@ -45,15 +66,24 @@ impl PgConnection { self.stream.write(d); } - pub(crate) fn write_bind(&mut self, portal: &str, statement: StatementId, args: &PgArguments) { + pub(crate) async fn write_bind( + &mut self, + portal: &str, + statement: StatementId, + args: &mut PgArguments, + ) -> crate::Result<()> { + args.buffer.patch_type_holes(self).await?; + self.stream.write(protocol::Bind { portal, statement, formats: &[TypeFormat::Binary], values_len: args.types.len() as i16, - values: &*args.values, + values: &*args.buffer, result_formats: &[TypeFormat::Binary], }); + + Ok(()) } pub(crate) fn write_execute(&mut self, portal: &str, limit: i32) { @@ -95,14 +125,14 @@ impl PgConnection { query: &str, arguments: Option, ) -> crate::Result> { - let statement = if let Some(arguments) = arguments { + let statement = if let Some(mut arguments) = arguments { // Check the statement cache for a statement ID that matches the given query // If it doesn't exist, we generate a new statement ID and write out [Parse] to the // connection command buffer - let statement = self.write_prepare(query, &arguments); + let statement = self.write_prepare(query, &arguments).await?; // Next, [Bind] attaches the arguments to the statement and creates a named portal - self.write_bind("", statement, &arguments); + self.write_bind("", statement, &mut arguments).await?; // Next, [Describe] will return the expected result columns and types // Conditionally run [Describe] only if the results have not been cached @@ -142,7 +172,7 @@ impl PgConnection { ) -> crate::Result> { self.is_ready = false; - let statement = self.write_prepare(query, &Default::default()); + let statement = self.write_prepare(query, &Default::default()).await?; self.write_describe(protocol::Describe::Statement(statement)); self.write_sync(); @@ -209,6 +239,42 @@ impl PgConnection { }) } + pub(crate) async fn get_type_id_by_name(&mut self, name: &str) -> crate::Result { + if let Some(oid) = self.cache_type_oid.get(name) { + return Ok(*oid); + } + + // language=SQL + let (oid,): (u32,) = query_as( + " +SElECT oid FROM pg_catalog.pg_type WHERE typname = $1 + ", + ) + .bind(name) + .fetch_one(&mut *self) + .await?; + + let shared = SharedStr::from(name.to_owned()); + + self.cache_type_oid.insert(shared.clone(), oid); + self.cache_type_name.insert(oid, shared.clone()); + + Ok(oid) + } + + pub(crate) fn get_type_info_by_oid(&mut self, oid: u32) -> PgTypeInfo { + if let Some(name) = try_resolve_type_name(oid) { + return PgTypeInfo::new(TypeId(oid), name); + } + + if let Some(name) = self.cache_type_name.get(&oid) { + return PgTypeInfo::new(TypeId(oid), name); + } + + // NOTE: The name isn't too important for the decode lifecycle + return PgTypeInfo::new(TypeId(oid), ""); + } + async fn get_type_names( &mut self, ids: impl IntoIterator, diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs index dab25ab8..e581a173 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -1,16 +1,18 @@ //! **Postgres** database and connection types. pub use arguments::PgArguments; +pub use buffer::PgRawBuffer; pub use connection::PgConnection; pub use cursor::PgCursor; pub use database::Postgres; pub use error::PgError; pub use listen::{PgListener, PgNotification}; pub use row::PgRow; -pub use types::PgTypeInfo; +pub use type_info::PgTypeInfo; pub use value::{PgData, PgValue}; mod arguments; +mod buffer; mod connection; mod cursor; mod database; @@ -22,6 +24,7 @@ mod row; mod sasl; mod stream; mod tls; +mod type_info; pub mod types; mod value; diff --git a/sqlx-core/src/postgres/protocol/type_id.rs b/sqlx-core/src/postgres/protocol/type_id.rs index be3655a9..6da81511 100644 --- a/sqlx-core/src/postgres/protocol/type_id.rs +++ b/sqlx-core/src/postgres/protocol/type_id.rs @@ -1,3 +1,4 @@ +use crate::postgres::types::try_resolve_type_name; use std::fmt::{self, Display}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -66,7 +67,7 @@ impl TypeId { pub(crate) const ARRAY_BPCHAR: TypeId = TypeId(1014); pub(crate) const ARRAY_NAME: TypeId = TypeId(1003); - pub(crate) const ARRAY_NUMERIC: TypeId = TypeId(1700); + pub(crate) const ARRAY_NUMERIC: TypeId = TypeId(1231); pub(crate) const ARRAY_DATE: TypeId = TypeId(1182); pub(crate) const ARRAY_TIME: TypeId = TypeId(1183); @@ -84,80 +85,19 @@ impl TypeId { pub(crate) const JSON: TypeId = TypeId(114); pub(crate) const JSONB: TypeId = TypeId(3802); + + // Records + + pub(crate) const RECORD: TypeId = TypeId(2249); + pub(crate) const ARRAY_RECORD: TypeId = TypeId(2287); } impl Display for TypeId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - TypeId::BOOL => f.write_str("BOOL"), - - TypeId::CHAR => f.write_str("\"CHAR\""), - - TypeId::INT2 => f.write_str("INT2"), - TypeId::INT4 => f.write_str("INT4"), - TypeId::INT8 => f.write_str("INT8"), - - TypeId::OID => f.write_str("OID"), - - TypeId::FLOAT4 => f.write_str("FLOAT4"), - TypeId::FLOAT8 => f.write_str("FLOAT8"), - - TypeId::NUMERIC => f.write_str("NUMERIC"), - - TypeId::TEXT => f.write_str("TEXT"), - TypeId::VARCHAR => f.write_str("VARCHAR"), - TypeId::BPCHAR => f.write_str("BPCHAR"), - TypeId::UNKNOWN => f.write_str("UNKNOWN"), - TypeId::NAME => f.write_str("NAME"), - - TypeId::DATE => f.write_str("DATE"), - TypeId::TIME => f.write_str("TIME"), - TypeId::TIMESTAMP => f.write_str("TIMESTAMP"), - TypeId::TIMESTAMPTZ => f.write_str("TIMESTAMPTZ"), - - TypeId::BYTEA => f.write_str("BYTEA"), - - TypeId::UUID => f.write_str("UUID"), - - TypeId::CIDR => f.write_str("CIDR"), - TypeId::INET => f.write_str("INET"), - - TypeId::ARRAY_BOOL => f.write_str("BOOL[]"), - - TypeId::ARRAY_CHAR => f.write_str("\"CHAR\"[]"), - - TypeId::ARRAY_INT2 => f.write_str("INT2[]"), - TypeId::ARRAY_INT4 => f.write_str("INT4[]"), - TypeId::ARRAY_INT8 => f.write_str("INT8[]"), - - TypeId::ARRAY_OID => f.write_str("OID[]"), - - TypeId::ARRAY_FLOAT4 => f.write_str("FLOAT4[]"), - TypeId::ARRAY_FLOAT8 => f.write_str("FLOAT8[]"), - - TypeId::ARRAY_TEXT => f.write_str("TEXT[]"), - TypeId::ARRAY_VARCHAR => f.write_str("VARCHAR[]"), - TypeId::ARRAY_BPCHAR => f.write_str("BPCHAR[]"), - TypeId::ARRAY_NAME => f.write_str("NAME[]"), - - TypeId::ARRAY_NUMERIC => f.write_str("NUMERIC[]"), - - TypeId::ARRAY_DATE => f.write_str("DATE[]"), - TypeId::ARRAY_TIME => f.write_str("TIME[]"), - TypeId::ARRAY_TIMESTAMP => f.write_str("TIMESTAMP[]"), - TypeId::ARRAY_TIMESTAMPTZ => f.write_str("TIMESTAMPTZ[]"), - - TypeId::ARRAY_BYTEA => f.write_str("BYTEA[]"), - - TypeId::ARRAY_UUID => f.write_str("UUID[]"), - - TypeId::ARRAY_CIDR => f.write_str("CIDR[]"), - TypeId::ARRAY_INET => f.write_str("INET[]"), - - TypeId::JSON => f.write_str("JSON"), - TypeId::JSONB => f.write_str("JSONB"), - - _ => write!(f, "<{}>", self.0), + if let Some(name) = try_resolve_type_name(self.0) { + f.write_str(name) + } else { + write!(f, "<{}>", self.0) } } } diff --git a/sqlx-core/src/postgres/row.rs b/sqlx-core/src/postgres/row.rs index 1de10e1c..345111ff 100644 --- a/sqlx-core/src/postgres/row.rs +++ b/sqlx-core/src/postgres/row.rs @@ -1,9 +1,9 @@ use std::collections::HashMap; use std::sync::Arc; -use crate::postgres::protocol::{DataRow, TypeFormat, TypeId}; +use crate::postgres::protocol::{DataRow, TypeFormat}; use crate::postgres::value::PgValue; -use crate::postgres::Postgres; +use crate::postgres::{PgTypeInfo, Postgres}; use crate::row::{ColumnIndex, Row}; // A statement has 0 or more columns being returned from the database @@ -11,7 +11,7 @@ use crate::row::{ColumnIndex, Row}; // For simple (unprepared) queries, format will always be text // For prepared queries, format will _almost_ always be binary pub(crate) struct Column { - pub(crate) type_id: TypeId, + pub(crate) type_info: PgTypeInfo, pub(crate) format: TypeFormat, } @@ -53,9 +53,9 @@ impl<'c> Row<'c> for PgRow<'c> { let column = &self.statement.columns[index]; let buffer = self.data.get(index); let value = match (column.format, buffer) { - (_, None) => PgValue::null(column.type_id), - (TypeFormat::Binary, Some(buf)) => PgValue::bytes(column.type_id, buf), - (TypeFormat::Text, Some(buf)) => PgValue::utf8(column.type_id, buf)?, + (_, None) => PgValue::null(), + (TypeFormat::Binary, Some(buf)) => PgValue::bytes(column.type_info.clone(), buf), + (TypeFormat::Text, Some(buf)) => PgValue::utf8(column.type_info.clone(), buf)?, }; Ok(value) diff --git a/sqlx-core/src/postgres/type_info.rs b/sqlx-core/src/postgres/type_info.rs new file mode 100644 index 00000000..d3e08e2c --- /dev/null +++ b/sqlx-core/src/postgres/type_info.rs @@ -0,0 +1,184 @@ +use crate::postgres::protocol::TypeId; +use crate::types::TypeInfo; +use std::borrow::Borrow; +use std::fmt; +use std::fmt::Display; +use std::ops::Deref; +use std::sync::Arc; + +/// Type information for a Postgres SQL type. +#[derive(Debug, Clone)] +pub struct PgTypeInfo { + pub(crate) id: Option, + pub(crate) name: SharedStr, +} + +impl PgTypeInfo { + pub(crate) fn new(id: TypeId, name: impl Into) -> Self { + Self { + id: Some(id), + name: name.into(), + } + } + + /// Create a `PgTypeInfo` from a type name. + /// + /// The OID for the type will be fetched from Postgres on bind or decode of + /// a value of this type. The fetched OID will be cached per-connection. + pub const fn with_name(name: &'static str) -> Self { + Self { + id: None, + name: SharedStr::Static(name), + } + } + + #[doc(hidden)] + pub fn type_feature_gate(&self) -> Option<&'static str> { + match self.id? { + TypeId::DATE | TypeId::TIME | TypeId::TIMESTAMP | TypeId::TIMESTAMPTZ => Some("chrono"), + TypeId::UUID => Some("uuid"), + TypeId::JSON | TypeId::JSONB => Some("json"), + // we can support decoding `PgNumeric` but it's decidedly less useful to the layman + TypeId::NUMERIC => Some("bigdecimal"), + TypeId::CIDR | TypeId::INET => Some("ipnetwork"), + + _ => None, + } + } +} + +impl Display for PgTypeInfo { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.name) + } +} + +impl PartialEq for PgTypeInfo { + fn eq(&self, other: &PgTypeInfo) -> bool { + // Postgres is strongly typed (mostly) so the rules that make sense here are equivalent + // to the rules that make sense in [compatible] + self.compatible(other) + } +} + +impl TypeInfo for PgTypeInfo { + fn compatible(&self, other: &Self) -> bool { + if let (Some(self_id), Some(other_id)) = (self.id, other.id) { + return match (self_id, other_id) { + (TypeId::CIDR, TypeId::INET) + | (TypeId::INET, TypeId::CIDR) + | (TypeId::ARRAY_CIDR, TypeId::ARRAY_INET) + | (TypeId::ARRAY_INET, TypeId::ARRAY_CIDR) => true, + + // the following text-like types are compatible + (TypeId::VARCHAR, other) + | (TypeId::TEXT, other) + | (TypeId::BPCHAR, other) + | (TypeId::NAME, other) + | (TypeId::UNKNOWN, other) + if match other { + TypeId::VARCHAR + | TypeId::TEXT + | TypeId::BPCHAR + | TypeId::NAME + | TypeId::UNKNOWN => true, + _ => false, + } => + { + true + } + + // the following text-like array types are compatible + (TypeId::ARRAY_VARCHAR, other) + | (TypeId::ARRAY_TEXT, other) + | (TypeId::ARRAY_BPCHAR, other) + | (TypeId::ARRAY_NAME, other) + if match other { + TypeId::ARRAY_VARCHAR + | TypeId::ARRAY_TEXT + | TypeId::ARRAY_BPCHAR + | TypeId::ARRAY_NAME => true, + _ => false, + } => + { + true + } + + // JSON <=> JSONB + (TypeId::JSON, other) | (TypeId::JSONB, other) + if match other { + TypeId::JSON | TypeId::JSONB => true, + _ => false, + } => + { + true + } + + _ => self_id.0 == other_id.0, + }; + } + + // If the type names match, the types are equivalent (and compatible) + // If the type names are the empty string, they are invalid type names + + if (&*self.name == &*other.name) && !self.name.is_empty() { + return true; + } + + // TODO: More efficient way to do case insensitive comparison + if !self.name.is_empty() && (&*self.name.to_lowercase() == &*other.name.to_lowercase()) { + return true; + } + + false + } +} + +/// Copy of `Cow` but for strings; clones guaranteed to be cheap. +#[derive(Clone, Debug, PartialEq, Hash, Eq)] +pub(crate) enum SharedStr { + Static(&'static str), + Arc(Arc), +} + +impl Deref for SharedStr { + type Target = str; + + fn deref(&self) -> &str { + match self { + SharedStr::Static(s) => s, + SharedStr::Arc(s) => s, + } + } +} + +impl Borrow for SharedStr { + fn borrow(&self) -> &str { + &**self + } +} + +impl<'a> From<&'a SharedStr> for SharedStr { + fn from(s: &'a SharedStr) -> Self { + s.clone() + } +} + +impl From<&'static str> for SharedStr { + fn from(s: &'static str) -> Self { + SharedStr::Static(s) + } +} + +impl From for SharedStr { + #[inline] + fn from(s: String) -> Self { + SharedStr::Arc(s.into()) + } +} + +impl fmt::Display for SharedStr { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.pad(self) + } +} diff --git a/sqlx-core/src/postgres/types/array.rs b/sqlx-core/src/postgres/types/array.rs index 2dde4c83..45f71697 100644 --- a/sqlx-core/src/postgres/types/array.rs +++ b/sqlx-core/src/postgres/types/array.rs @@ -3,9 +3,8 @@ use crate::database::Database; use crate::decode::Decode; use crate::encode::Encode; -use crate::postgres::database::Postgres; use crate::postgres::types::raw::{PgArrayDecoder, PgArrayEncoder}; -use crate::postgres::PgValue; +use crate::postgres::{PgRawBuffer, PgValue, Postgres}; use crate::types::Type; impl Encode for [T] @@ -13,7 +12,7 @@ where T: Encode, T: Type, { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { let mut encoder = PgArrayEncoder::new(buf); for item in self { @@ -29,7 +28,7 @@ where T: Encode, T: Type, { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { self.as_slice().encode(buf) } } diff --git a/sqlx-core/src/postgres/types/bigdecimal.rs b/sqlx-core/src/postgres/types/bigdecimal.rs index b8ff25a0..a1774689 100644 --- a/sqlx-core/src/postgres/types/bigdecimal.rs +++ b/sqlx-core/src/postgres/types/bigdecimal.rs @@ -6,7 +6,7 @@ use num_bigint::{BigInt, Sign}; use crate::decode::Decode; use crate::encode::Encode; -use crate::postgres::{PgData, PgTypeInfo, PgValue, Postgres}; +use crate::postgres::{PgData, PgRawBuffer, PgTypeInfo, PgValue, Postgres}; use crate::types::Type; use super::raw::{PgNumeric, PgNumericSign}; @@ -135,7 +135,7 @@ impl TryFrom for BigDecimal { /// ### Panics /// If this `BigDecimal` cannot be represented by [PgNumeric]. impl Encode for BigDecimal { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { // this copy is unfortunately necessary because we'd have to use `.to_bigint_and_exponent()` // otherwise which does the exact same thing, so for the explicit impl might as well allow // the user to skip one of the copies if they have an owned value diff --git a/sqlx-core/src/postgres/types/bool.rs b/sqlx-core/src/postgres/types/bool.rs index 3658a169..6fa649e3 100644 --- a/sqlx-core/src/postgres/types/bool.rs +++ b/sqlx-core/src/postgres/types/bool.rs @@ -1,7 +1,7 @@ use crate::decode::Decode; use crate::encode::Encode; use crate::postgres::protocol::TypeId; -use crate::postgres::{PgData, PgTypeInfo, PgValue, Postgres}; +use crate::postgres::{PgData, PgRawBuffer, PgTypeInfo, PgValue, Postgres}; use crate::types::Type; impl Type for bool { @@ -22,7 +22,7 @@ impl Type for Vec { } impl Encode for bool { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { buf.push(*self as u8); } } diff --git a/sqlx-core/src/postgres/types/bytes.rs b/sqlx-core/src/postgres/types/bytes.rs index b112c164..5a3885d4 100644 --- a/sqlx-core/src/postgres/types/bytes.rs +++ b/sqlx-core/src/postgres/types/bytes.rs @@ -1,8 +1,7 @@ use crate::decode::Decode; use crate::encode::Encode; use crate::postgres::protocol::TypeId; -use crate::postgres::types::PgTypeInfo; -use crate::postgres::{PgData, PgValue, Postgres}; +use crate::postgres::{PgData, PgRawBuffer, PgTypeInfo, PgValue, Postgres}; use crate::types::Type; impl Type for [u8] { @@ -30,13 +29,13 @@ impl Type for Vec { } impl Encode for [u8] { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { buf.extend_from_slice(self); } } impl Encode for Vec { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { <[u8] as Encode>::encode(self, buf); } } diff --git a/sqlx-core/src/postgres/types/chrono.rs b/sqlx-core/src/postgres/types/chrono.rs index 94855e7f..824d0a12 100644 --- a/sqlx-core/src/postgres/types/chrono.rs +++ b/sqlx-core/src/postgres/types/chrono.rs @@ -7,8 +7,7 @@ use chrono::{DateTime, Duration, Local, NaiveDate, NaiveDateTime, NaiveTime, Tim use crate::decode::Decode; use crate::encode::Encode; use crate::postgres::protocol::TypeId; -use crate::postgres::types::PgTypeInfo; -use crate::postgres::{PgData, PgValue, Postgres}; +use crate::postgres::{PgData, PgRawBuffer, PgTypeInfo, PgValue, Postgres}; use crate::types::Type; use crate::Error; @@ -108,7 +107,7 @@ impl<'de> Decode<'de, Postgres> for NaiveTime { } impl Encode for NaiveTime { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { let micros = (*self - NaiveTime::from_hms(0, 0, 0)) .num_microseconds() .expect("shouldn't overflow"); @@ -136,7 +135,7 @@ impl<'de> Decode<'de, Postgres> for NaiveDate { } impl Encode for NaiveDate { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { let days: i32 = self .signed_duration_since(NaiveDate::from_ymd(2000, 1, 1)) .num_days() @@ -191,7 +190,7 @@ impl<'de> Decode<'de, Postgres> for NaiveDateTime { } impl Encode for NaiveDateTime { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { let micros = self .signed_duration_since(postgres_epoch().naive_utc()) .num_microseconds() @@ -223,7 +222,7 @@ impl Encode for DateTime where Tz::Offset: Copy, { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { Encode::::encode(&self.naive_utc(), buf); } @@ -238,60 +237,60 @@ fn postgres_epoch() -> DateTime { #[test] fn test_encode_time() { - let mut buf = Vec::new(); + let mut buf = PgRawBuffer::default(); Encode::::encode(&NaiveTime::from_hms(0, 0, 0), &mut buf); - assert_eq!(buf, [0; 8]); + assert_eq!(&**buf, [0; 8]); buf.clear(); // one second Encode::::encode(&NaiveTime::from_hms(0, 0, 1), &mut buf); - assert_eq!(buf, 1_000_000i64.to_be_bytes()); + assert_eq!(&**buf, 1_000_000i64.to_be_bytes()); buf.clear(); // two hours Encode::::encode(&NaiveTime::from_hms(2, 0, 0), &mut buf); let expected = 1_000_000i64 * 60 * 60 * 2; - assert_eq!(buf, expected.to_be_bytes()); + assert_eq!(&**buf, expected.to_be_bytes()); buf.clear(); // 3:14:15.000001 Encode::::encode(&NaiveTime::from_hms_micro(3, 14, 15, 1), &mut buf); let expected = 1_000_000i64 * 60 * 60 * 3 + 1_000_000i64 * 60 * 14 + 1_000_000i64 * 15 + 1; - assert_eq!(buf, expected.to_be_bytes()); + assert_eq!(&**buf, expected.to_be_bytes()); buf.clear(); } #[test] fn test_decode_time() { let buf = [0u8; 8]; - let time: NaiveTime = Decode::::decode(PgValue::bytes(TypeId(0), &buf)).unwrap(); + let time: NaiveTime = Decode::::decode(PgValue::from_bytes(&buf)).unwrap(); assert_eq!(time, NaiveTime::from_hms(0, 0, 0),); // half an hour let buf = (1_000_000i64 * 60 * 30).to_be_bytes(); - let time: NaiveTime = Decode::::decode(PgValue::bytes(TypeId(0), &buf)).unwrap(); + let time: NaiveTime = Decode::::decode(PgValue::from_bytes(&buf)).unwrap(); assert_eq!(time, NaiveTime::from_hms(0, 30, 0),); // 12:53:05.125305 let buf = (1_000_000i64 * 60 * 60 * 12 + 1_000_000i64 * 60 * 53 + 1_000_000i64 * 5 + 125305) .to_be_bytes(); - let time: NaiveTime = Decode::::decode(PgValue::bytes(TypeId(0), &buf)).unwrap(); + let time: NaiveTime = Decode::::decode(PgValue::from_bytes(&buf)).unwrap(); assert_eq!(time, NaiveTime::from_hms_micro(12, 53, 5, 125305),); } #[test] fn test_encode_datetime() { - let mut buf = Vec::new(); + let mut buf = PgRawBuffer::default(); let date = postgres_epoch(); Encode::::encode(&date, &mut buf); - assert_eq!(buf, [0; 8]); + assert_eq!(&**buf, [0; 8]); buf.clear(); // one hour past epoch let date2 = postgres_epoch() + Duration::hours(1); Encode::::encode(&date2, &mut buf); - assert_eq!(buf, 3_600_000_000i64.to_be_bytes()); + assert_eq!(&**buf, 3_600_000_000i64.to_be_bytes()); buf.clear(); // some random date @@ -300,57 +299,57 @@ fn test_encode_datetime() { .num_microseconds() .unwrap()); Encode::::encode(&date3, &mut buf); - assert_eq!(buf, expected.to_be_bytes()); + assert_eq!(&**buf, expected.to_be_bytes()); buf.clear(); } #[test] fn test_decode_datetime() { let buf = [0u8; 8]; - let date: NaiveDateTime = Decode::::decode(PgValue::bytes(TypeId(0), &buf)).unwrap(); + let date: NaiveDateTime = Decode::::decode(PgValue::from_bytes(&buf)).unwrap(); assert_eq!(date.to_string(), "2000-01-01 00:00:00"); let buf = 3_600_000_000i64.to_be_bytes(); - let date: NaiveDateTime = Decode::::decode(PgValue::bytes(TypeId(0), &buf)).unwrap(); + let date: NaiveDateTime = Decode::::decode(PgValue::from_bytes(&buf)).unwrap(); assert_eq!(date.to_string(), "2000-01-01 01:00:00"); let buf = 629_377_265_000_000i64.to_be_bytes(); - let date: NaiveDateTime = Decode::::decode(PgValue::bytes(TypeId(0), &buf)).unwrap(); + let date: NaiveDateTime = Decode::::decode(PgValue::from_bytes(&buf)).unwrap(); assert_eq!(date.to_string(), "2019-12-11 11:01:05"); } #[test] fn test_encode_date() { - let mut buf = Vec::new(); + let mut buf = PgRawBuffer::default(); let date = NaiveDate::from_ymd(2000, 1, 1); Encode::::encode(&date, &mut buf); - assert_eq!(buf, [0; 4]); + assert_eq!(&**buf, [0; 4]); buf.clear(); let date2 = NaiveDate::from_ymd(2001, 1, 1); Encode::::encode(&date2, &mut buf); // 2000 was a leap year - assert_eq!(buf, 366i32.to_be_bytes()); + assert_eq!(&**buf, 366i32.to_be_bytes()); buf.clear(); let date3 = NaiveDate::from_ymd(2019, 12, 11); Encode::::encode(&date3, &mut buf); - assert_eq!(buf, 7284i32.to_be_bytes()); + assert_eq!(&**buf, 7284i32.to_be_bytes()); buf.clear(); } #[test] fn test_decode_date() { let buf = [0; 4]; - let date: NaiveDate = Decode::::decode(PgValue::bytes(TypeId(0), &buf)).unwrap(); + let date: NaiveDate = Decode::::decode(PgValue::from_bytes(&buf)).unwrap(); assert_eq!(date.to_string(), "2000-01-01"); let buf = 366i32.to_be_bytes(); - let date: NaiveDate = Decode::::decode(PgValue::bytes(TypeId(0), &buf)).unwrap(); + let date: NaiveDate = Decode::::decode(PgValue::from_bytes(&buf)).unwrap(); assert_eq!(date.to_string(), "2001-01-01"); let buf = 7284i32.to_be_bytes(); - let date: NaiveDate = Decode::::decode(PgValue::bytes(TypeId(0), &buf)).unwrap(); + let date: NaiveDate = Decode::::decode(PgValue::from_bytes(&buf)).unwrap(); assert_eq!(date.to_string(), "2019-12-11"); } diff --git a/sqlx-core/src/postgres/types/float.rs b/sqlx-core/src/postgres/types/float.rs index a287790e..413e54dc 100644 --- a/sqlx-core/src/postgres/types/float.rs +++ b/sqlx-core/src/postgres/types/float.rs @@ -6,8 +6,7 @@ use crate::decode::Decode; use crate::encode::Encode; use crate::error::Error; use crate::postgres::protocol::TypeId; -use crate::postgres::types::PgTypeInfo; -use crate::postgres::{PgData, PgValue, Postgres}; +use crate::postgres::{PgData, PgRawBuffer, PgTypeInfo, PgValue, Postgres}; use crate::types::Type; impl Type for f32 { @@ -28,7 +27,7 @@ impl Type for Vec { } impl Encode for f32 { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { >::encode(&(self.to_bits() as i32), buf) } } @@ -64,7 +63,7 @@ impl Type for Vec { } impl Encode for f64 { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { >::encode(&(self.to_bits() as i64), buf) } } diff --git a/sqlx-core/src/postgres/types/int.rs b/sqlx-core/src/postgres/types/int.rs index 0c13e11c..0a789e5f 100644 --- a/sqlx-core/src/postgres/types/int.rs +++ b/sqlx-core/src/postgres/types/int.rs @@ -5,8 +5,7 @@ use byteorder::{NetworkEndian, ReadBytesExt}; use crate::decode::Decode; use crate::encode::Encode; use crate::postgres::protocol::TypeId; -use crate::postgres::types::PgTypeInfo; -use crate::postgres::{PgData, PgValue, Postgres}; +use crate::postgres::{PgData, PgRawBuffer, PgTypeInfo, PgValue, Postgres}; use crate::types::Type; use crate::Error; @@ -29,7 +28,7 @@ impl Type for Vec { } impl Encode for i8 { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { buf.extend_from_slice(&self.to_be_bytes()); } } @@ -62,7 +61,7 @@ impl Type for Vec { } impl Encode for i16 { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { buf.extend_from_slice(&self.to_be_bytes()); } } @@ -95,7 +94,7 @@ impl Type for Vec { } impl Encode for i32 { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { buf.extend_from_slice(&self.to_be_bytes()); } } @@ -128,7 +127,7 @@ impl Type for Vec { } impl Encode for u32 { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { buf.extend_from_slice(&self.to_be_bytes()); } } @@ -160,7 +159,7 @@ impl Type for Vec { } impl Encode for i64 { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { buf.extend_from_slice(&self.to_be_bytes()); } } diff --git a/sqlx-core/src/postgres/types/ipnetwork.rs b/sqlx-core/src/postgres/types/ipnetwork.rs index 7742c5d8..db4eae17 100644 --- a/sqlx-core/src/postgres/types/ipnetwork.rs +++ b/sqlx-core/src/postgres/types/ipnetwork.rs @@ -5,9 +5,8 @@ use ipnetwork::{IpNetwork, Ipv4Network, Ipv6Network}; use crate::decode::Decode; use crate::encode::Encode; use crate::postgres::protocol::TypeId; -use crate::postgres::types::PgTypeInfo; use crate::postgres::value::PgValue; -use crate::postgres::{PgData, Postgres}; +use crate::postgres::{PgData, PgRawBuffer, PgTypeInfo, Postgres}; use crate::types::Type; use crate::Error; @@ -38,7 +37,7 @@ impl Type for [IpNetwork] { } impl Encode for IpNetwork { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { match self { IpNetwork::V4(net) => { buf.push(PGSQL_AF_INET); diff --git a/sqlx-core/src/postgres/types/json.rs b/sqlx-core/src/postgres/types/json.rs index 2df3f54a..6b4b85ea 100644 --- a/sqlx-core/src/postgres/types/json.rs +++ b/sqlx-core/src/postgres/types/json.rs @@ -2,7 +2,7 @@ use crate::decode::Decode; use crate::encode::Encode; use crate::io::{Buf, BufMut}; use crate::postgres::protocol::TypeId; -use crate::postgres::{PgData, PgTypeInfo, PgValue, Postgres}; +use crate::postgres::{PgData, PgRawBuffer, PgTypeInfo, PgValue, Postgres}; use crate::types::{Json, Type}; use crate::value::RawValue; use serde::{Deserialize, Serialize}; @@ -22,7 +22,7 @@ impl Type for JsonValue { } impl Encode for JsonValue { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { Json(self).encode(buf) } } @@ -40,7 +40,7 @@ impl Type for &'_ JsonRawValue { } impl Encode for &'_ JsonRawValue { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { Json(self).encode(buf) } } @@ -61,11 +61,11 @@ impl Encode for Json where T: Serialize, { - fn encode(&self, buf: &mut Vec) { - // JSONB version (as of 2020-03-20 ) + fn encode(&self, buf: &mut PgRawBuffer) { + // JSONB version (as of 2020-03-20) buf.put_u8(1); - serde_json::to_writer(buf, &self.0) + serde_json::to_writer(&mut **buf, &self.0) .expect("failed to serialize json for encoding to database"); } } @@ -79,7 +79,7 @@ where (match value.try_get()? { PgData::Text(s) => serde_json::from_str(s), PgData::Binary(mut buf) => { - if value.type_info().as_ref().map(|info| info.id) == Some(TypeId::JSONB) { + if value.type_info().as_ref().and_then(|info| info.id) == Some(TypeId::JSONB) { let version = buf.get_u8()?; assert_eq!( diff --git a/sqlx-core/src/postgres/types/mod.rs b/sqlx-core/src/postgres/types/mod.rs index 5c05d198..9a7def68 100644 --- a/sqlx-core/src/postgres/types/mod.rs +++ b/sqlx-core/src/postgres/types/mod.rs @@ -128,14 +128,9 @@ //! a potentially `NULL` value from Postgres. //! -use std::fmt::{self, Debug, Display}; -use std::ops::Deref; -use std::sync::Arc; - use crate::decode::Decode; use crate::postgres::protocol::TypeId; use crate::postgres::{PgValue, Postgres}; -use crate::types::TypeInfo; mod array; mod bool; @@ -167,125 +162,9 @@ mod json; #[cfg(feature = "ipnetwork")] mod ipnetwork; -/// Type information for a Postgres SQL type. -#[derive(Debug, Clone)] -pub struct PgTypeInfo { - pub(crate) id: TypeId, - pub(crate) name: Option, -} - -impl PgTypeInfo { - pub(crate) fn new(id: TypeId, name: impl Into) -> Self { - Self { - id, - name: Some(name.into()), - } - } - - /// Create a `PgTypeInfo` from a type's object identifier. - /// - /// The object identifier of a type can be queried with - /// `SELECT oid FROM pg_type WHERE typname = ;` - pub fn with_oid(oid: u32) -> Self { - Self { - id: TypeId(oid), - name: None, - } - } - - #[doc(hidden)] - pub fn type_feature_gate(&self) -> Option<&'static str> { - match self.id { - TypeId::DATE | TypeId::TIME | TypeId::TIMESTAMP | TypeId::TIMESTAMPTZ => Some("chrono"), - TypeId::UUID => Some("uuid"), - // we can support decoding `PgNumeric` but it's decidedly less useful to the layman - TypeId::NUMERIC => Some("bigdecimal"), - TypeId::CIDR | TypeId::INET => Some("ipnetwork"), - _ => None, - } - } - - #[doc(hidden)] - pub fn oid(&self) -> u32 { - self.id.0 - } -} - -impl Display for PgTypeInfo { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if let Some(ref name) = self.name { - write!(f, "{}", *name) - } else { - write!(f, "{}", self.id) - } - } -} - -impl PartialEq for PgTypeInfo { - fn eq(&self, other: &PgTypeInfo) -> bool { - // Postgres is strongly typed (mostly) so the rules that make sense here are equivalent - // to the rules that make sense in [compatible] - self.compatible(other) - } -} - -impl TypeInfo for PgTypeInfo { - fn compatible(&self, other: &Self) -> bool { - match (self.id, other.id) { - (TypeId::CIDR, TypeId::INET) - | (TypeId::INET, TypeId::CIDR) - | (TypeId::ARRAY_CIDR, TypeId::ARRAY_INET) - | (TypeId::ARRAY_INET, TypeId::ARRAY_CIDR) => true, - - // the following text-like types are compatible - (TypeId::VARCHAR, other) - | (TypeId::TEXT, other) - | (TypeId::BPCHAR, other) - | (TypeId::NAME, other) - | (TypeId::UNKNOWN, other) - if match other { - TypeId::VARCHAR - | TypeId::TEXT - | TypeId::BPCHAR - | TypeId::NAME - | TypeId::UNKNOWN => true, - _ => false, - } => - { - true - } - - // the following text-like array types are compatible - (TypeId::ARRAY_VARCHAR, other) - | (TypeId::ARRAY_TEXT, other) - | (TypeId::ARRAY_BPCHAR, other) - | (TypeId::ARRAY_NAME, other) - if match other { - TypeId::ARRAY_VARCHAR - | TypeId::ARRAY_TEXT - | TypeId::ARRAY_BPCHAR - | TypeId::ARRAY_NAME => true, - _ => false, - } => - { - true - } - - // JSON <=> JSONB - (TypeId::JSON, other) | (TypeId::JSONB, other) - if match other { - TypeId::JSON | TypeId::JSONB => true, - _ => false, - } => - { - true - } - - _ => self.id.0 == other.id.0, - } - } -} - +// Implement `Decode` for all postgres types +// The concept of a nullable `RawValue` is db-specific +// `Type` is implemented generically at src/types.rs impl<'de, T> Decode<'de, Postgres> for Option where T: Decode<'de, Postgres>, @@ -299,45 +178,82 @@ where } } -/// Copy of `Cow` but for strings; clones guaranteed to be cheap. -#[derive(Clone, Debug)] -pub(crate) enum SharedStr { - Static(&'static str), - Arc(Arc), -} +// Try to resolve a _static_ type name from an OID +pub(crate) fn try_resolve_type_name(oid: u32) -> Option<&'static str> { + Some(match TypeId(oid) { + TypeId::BOOL => "BOOL", -impl Deref for SharedStr { - type Target = str; + TypeId::CHAR => "\"CHAR\"", - fn deref(&self) -> &str { - match self { - SharedStr::Static(s) => s, - SharedStr::Arc(s) => s, + TypeId::INT2 => "INT2", + TypeId::INT4 => "INT4", + TypeId::INT8 => "INT8", + + TypeId::OID => "OID", + + TypeId::FLOAT4 => "FLOAT4", + TypeId::FLOAT8 => "FLOAT8", + + TypeId::NUMERIC => "NUMERIC", + + TypeId::TEXT => "TEXT", + TypeId::VARCHAR => "VARCHAR", + TypeId::BPCHAR => "BPCHAR", + TypeId::UNKNOWN => "UNKNOWN", + TypeId::NAME => "NAME", + + TypeId::DATE => "DATE", + TypeId::TIME => "TIME", + TypeId::TIMESTAMP => "TIMESTAMP", + TypeId::TIMESTAMPTZ => "TIMESTAMPTZ", + + TypeId::BYTEA => "BYTEA", + + TypeId::UUID => "UUID", + + TypeId::CIDR => "CIDR", + TypeId::INET => "INET", + + TypeId::ARRAY_BOOL => "BOOL[]", + + TypeId::ARRAY_CHAR => "\"CHAR\"[]", + + TypeId::ARRAY_INT2 => "INT2[]", + TypeId::ARRAY_INT4 => "INT4[]", + TypeId::ARRAY_INT8 => "INT8[]", + + TypeId::ARRAY_OID => "OID[]", + + TypeId::ARRAY_FLOAT4 => "FLOAT4[]", + TypeId::ARRAY_FLOAT8 => "FLOAT8[]", + + TypeId::ARRAY_TEXT => "TEXT[]", + TypeId::ARRAY_VARCHAR => "VARCHAR[]", + TypeId::ARRAY_BPCHAR => "BPCHAR[]", + TypeId::ARRAY_NAME => "NAME[]", + + TypeId::ARRAY_NUMERIC => "NUMERIC[]", + + TypeId::ARRAY_DATE => "DATE[]", + TypeId::ARRAY_TIME => "TIME[]", + TypeId::ARRAY_TIMESTAMP => "TIMESTAMP[]", + TypeId::ARRAY_TIMESTAMPTZ => "TIMESTAMPTZ[]", + + TypeId::ARRAY_BYTEA => "BYTEA[]", + + TypeId::ARRAY_UUID => "UUID[]", + + TypeId::ARRAY_CIDR => "CIDR[]", + TypeId::ARRAY_INET => "INET[]", + + TypeId::JSON => "JSON", + TypeId::JSONB => "JSONB", + + TypeId::RECORD => "RECORD", + TypeId::ARRAY_RECORD => "RECORD[]", + + _ => { + return None; } - } -} - -impl<'a> From<&'a SharedStr> for SharedStr { - fn from(s: &'a SharedStr) -> Self { - s.clone() - } -} - -impl From<&'static str> for SharedStr { - fn from(s: &'static str) -> Self { - SharedStr::Static(s) - } -} - -impl From for SharedStr { - #[inline] - fn from(s: String) -> Self { - SharedStr::Arc(s.into()) - } -} - -impl fmt::Display for SharedStr { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.pad(self) - } + }) } diff --git a/sqlx-core/src/postgres/types/raw/array.rs b/sqlx-core/src/postgres/types/raw/array.rs index 6d571395..a9b7d647 100644 --- a/sqlx-core/src/postgres/types/raw/array.rs +++ b/sqlx-core/src/postgres/types/raw/array.rs @@ -2,7 +2,7 @@ use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::io::{Buf, BufMut}; use crate::postgres::types::raw::sequence::PgSequenceDecoder; -use crate::postgres::{PgData, PgValue, Postgres}; +use crate::postgres::{PgData, PgRawBuffer, PgValue, Postgres}; use crate::types::Type; use byteorder::BE; use std::marker::PhantomData; @@ -13,16 +13,15 @@ use std::marker::PhantomData; pub(crate) struct PgArrayEncoder<'enc, T> { count: usize, len_start_index: usize, - buf: &'enc mut Vec, + buf: &'enc mut PgRawBuffer, phantom: PhantomData, } impl<'enc, T> PgArrayEncoder<'enc, T> where - T: Encode, - T: Type, + T: Encode + Type, { - pub(crate) fn new(buf: &'enc mut Vec) -> Self { + pub(crate) fn new(buf: &'enc mut PgRawBuffer) -> Self { let ty = >::type_info(); // ndim @@ -31,8 +30,15 @@ where // dataoffset buf.put_i32::(0); - // elemtype - buf.put_i32::(ty.id.0 as i32); + // [elemtype] element type OID + if let Some(oid) = ty.id { + // write oid + buf.extend(&oid.0.to_be_bytes()); + } else { + // write hole for this oid + buf.push_type_hole(&ty.name); + } + let len_start_index = buf.len(); // dimensions @@ -96,14 +102,15 @@ where pub(crate) fn new(value: PgValue<'de>) -> crate::Result { let mut data = value.try_get()?; - match data { + let element_oid = match data { PgData::Binary(ref mut buf) => { // number of dimensions of the array let ndim = buf.get_i32::()?; if ndim == 0 { + // ndim of 0 is an empty array return Ok(Self { - inner: PgSequenceDecoder::new(PgData::Binary(&[]), false), + inner: PgSequenceDecoder::new(PgData::Binary(&[]), None), phantom: PhantomData, }); } @@ -119,12 +126,8 @@ where // this doesn't matter as the data is always at the end of the header let _dataoffset = buf.get_i32::()?; - // TODO: Validate element type with whatever framework is put in place to do so - // As a reminder, we have no way to do this yet and still account for [compatible] - // types. - // element type OID - let _elemtype = buf.get_i32::()?; + let element_oid = buf.get_u32::()?; // length of each array axis let _dimensions = buf.get_i32::()?; @@ -138,13 +141,15 @@ where lower_bnds )); } + + Some(element_oid) } - PgData::Text(_) => {} - } + PgData::Text(_) => None, + }; Ok(Self { - inner: PgSequenceDecoder::new(data, false), + inner: PgSequenceDecoder::new(data, element_oid), phantom: PhantomData, }) } @@ -171,14 +176,13 @@ where mod tests { use super::PgArrayDecoder; use super::PgArrayEncoder; - use crate::postgres::protocol::TypeId; - use crate::postgres::PgValue; + use crate::postgres::{PgRawBuffer, PgValue, Postgres}; const BUF_BINARY_I32: &[u8] = b"\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x17\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00\x04\x00\x00\x00\x04"; #[test] fn it_encodes_i32() { - let mut buf = Vec::new(); + let mut buf = PgRawBuffer::default(); let mut encoder = PgArrayEncoder::new(&mut buf); for val in &[1_i32, 2, 3, 4] { @@ -187,13 +191,13 @@ mod tests { encoder.finish(); - assert_eq!(buf, BUF_BINARY_I32); + assert_eq!(&**buf, BUF_BINARY_I32); } #[test] fn it_decodes_text_i32() -> crate::Result<()> { let s = "{1,152,-12412}"; - let mut decoder = PgArrayDecoder::::new(PgValue::str(TypeId(0), s))?; + let mut decoder = PgArrayDecoder::::new(PgValue::from_str(s))?; assert_eq!(decoder.decode()?, Some(1)); assert_eq!(decoder.decode()?, Some(152)); @@ -206,7 +210,7 @@ mod tests { #[test] fn it_decodes_text_str() -> crate::Result<()> { let s = "{\"\",\"\\\"\"}"; - let mut decoder = PgArrayDecoder::::new(PgValue::str(TypeId(0), s))?; + let mut decoder = PgArrayDecoder::::new(PgValue::from_str(s))?; assert_eq!(decoder.decode()?, Some("".to_string())); assert_eq!(decoder.decode()?, Some("\"".to_string())); @@ -217,7 +221,7 @@ mod tests { #[test] fn it_decodes_binary_nulls() -> crate::Result<()> { - let mut decoder = PgArrayDecoder::>::new(PgValue::bytes(TypeId(0), + let mut decoder = PgArrayDecoder::>::new(PgValue::from_bytes( b"\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\xff\xff\xff\xff\x00\x00\x00\x01\x01\xff\xff\xff\xff\x00\x00\x00\x01\x00", ))?; @@ -231,7 +235,7 @@ mod tests { #[test] fn it_decodes_binary_i32() -> crate::Result<()> { - let mut decoder = PgArrayDecoder::::new(PgValue::bytes(TypeId(0), BUF_BINARY_I32))?; + let mut decoder = PgArrayDecoder::::new(PgValue::from_bytes(BUF_BINARY_I32))?; let val_1 = decoder.decode()?; let val_2 = decoder.decode()?; diff --git a/sqlx-core/src/postgres/types/raw/numeric.rs b/sqlx-core/src/postgres/types/raw/numeric.rs index 829c7a7f..fe2163a5 100644 --- a/sqlx-core/src/postgres/types/raw/numeric.rs +++ b/sqlx-core/src/postgres/types/raw/numeric.rs @@ -6,7 +6,7 @@ use crate::decode::Decode; use crate::encode::Encode; use crate::io::{Buf, BufMut}; use crate::postgres::protocol::TypeId; -use crate::postgres::{PgData, PgTypeInfo, PgValue, Postgres}; +use crate::postgres::{PgData, PgRawBuffer, PgTypeInfo, PgValue, Postgres}; use crate::types::Type; use crate::Error; @@ -125,7 +125,7 @@ impl Decode<'_, Postgres> for PgNumeric { /// * If `digits.len()` overflows `i16` /// * If any element in `digits` is greater than or equal to 10000 impl Encode for PgNumeric { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { match *self { PgNumeric::Number { ref digits, diff --git a/sqlx-core/src/postgres/types/raw/record.rs b/sqlx-core/src/postgres/types/raw/record.rs index b04847c6..40867177 100644 --- a/sqlx-core/src/postgres/types/raw/record.rs +++ b/sqlx-core/src/postgres/types/raw/record.rs @@ -2,18 +2,18 @@ use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::io::Buf; use crate::postgres::types::raw::sequence::PgSequenceDecoder; -use crate::postgres::{PgData, PgValue, Postgres}; +use crate::postgres::{PgData, PgRawBuffer, PgValue, Postgres}; use crate::types::Type; use byteorder::BigEndian; pub struct PgRecordEncoder<'a> { - buf: &'a mut Vec, + buf: &'a mut PgRawBuffer, beg: usize, num: u32, } impl<'a> PgRecordEncoder<'a> { - pub fn new(buf: &'a mut Vec) -> Self { + pub fn new(buf: &'a mut PgRawBuffer) -> Self { // reserve space for a field count buf.extend_from_slice(&(0_u32).to_be_bytes()); @@ -33,9 +33,15 @@ impl<'a> PgRecordEncoder<'a> { where T: Type + Encode, { - // write oid let info = T::type_info(); - self.buf.extend(&info.oid().to_be_bytes()); + + if let Some(oid) = info.id { + // write oid + self.buf.extend(&oid.0.to_be_bytes()); + } else { + // write hole for this oid + self.buf.push_type_hole(&info.name); + } // write zeros for length self.buf.extend(&[0; 4]); @@ -71,7 +77,7 @@ impl<'de> PgRecordDecoder<'de> { } } - Ok(Self(PgSequenceDecoder::new(data, true))) + Ok(Self(PgSequenceDecoder::new(data, None))) } #[inline] @@ -91,15 +97,15 @@ fn test_encode_field() { use std::convert::TryInto; let value = "Foo Bar"; - let mut raw_encoded = Vec::new(); + let mut raw_encoded = PgRawBuffer::default(); <&str as Encode>::encode(&value, &mut raw_encoded); - let mut field_encoded = Vec::new(); + let mut field_encoded = PgRawBuffer::default(); let mut encoder = PgRecordEncoder::new(&mut field_encoded); encoder.encode(&value); // check oid - let oid = <&str as Type>::type_info().oid(); + let oid = <&str as Type>::type_info().id.unwrap().0; let field_encoded_oid = u32::from_be_bytes(field_encoded[4..8].try_into().unwrap()); assert_eq!(oid, field_encoded_oid); @@ -108,21 +114,19 @@ fn test_encode_field() { assert_eq!(raw_encoded.len(), field_encoded_length as usize); // check data - assert_eq!(raw_encoded, &field_encoded[12..]); + assert_eq!(&**raw_encoded, &field_encoded[12..]); } #[test] fn test_decode_field() { - use crate::postgres::protocol::TypeId; - let value = "Foo Bar".to_string(); - let mut buf = Vec::new(); + let mut buf = PgRawBuffer::default(); let mut encoder = PgRecordEncoder::new(&mut buf); encoder.encode(&value); let buf = buf.as_slice(); - let mut decoder = PgRecordDecoder::new(PgValue::bytes(TypeId(0), buf)).unwrap(); + let mut decoder = PgRecordDecoder::new(PgValue::from_bytes(buf)).unwrap(); let value_decoded: String = decoder.decode().unwrap(); assert_eq!(value_decoded, value); diff --git a/sqlx-core/src/postgres/types/raw/sequence.rs b/sqlx-core/src/postgres/types/raw/sequence.rs index ae4d44a0..57b7f710 100644 --- a/sqlx-core/src/postgres/types/raw/sequence.rs +++ b/sqlx-core/src/postgres/types/raw/sequence.rs @@ -8,11 +8,14 @@ use byteorder::BigEndian; pub(crate) struct PgSequenceDecoder<'de> { data: PgData<'de>, len: usize, - mixed: bool, + is_text_record: bool, + element_oid: Option, } impl<'de> PgSequenceDecoder<'de> { - pub(crate) fn new(mut data: PgData<'de>, mixed: bool) -> Self { + pub(crate) fn new(mut data: PgData<'de>, element_oid: Option) -> Self { + let mut is_text_record = false; + match data { PgData::Binary(_) => { // assume that this has already gotten tweaked by the caller as @@ -20,14 +23,16 @@ impl<'de> PgSequenceDecoder<'de> { } PgData::Text(ref mut s) => { + is_text_record = s.as_bytes()[0] == b'('; // remove the outer ( ... ) or { ... } *s = &s[1..(s.len() - 1)]; } } Self { + is_text_record, + element_oid, data, - mixed, len: 0, } } @@ -47,34 +52,34 @@ impl<'de> PgSequenceDecoder<'de> { return Ok(None); } - // mixed sequences can contain values of many different types - // the OID of the type is encoded next to each value - let type_id = if self.mixed { - let oid = buf.get_u32::()?; - let expected_ty = PgTypeInfo::with_oid(oid); + let type_info = if let Some(element_oid) = self.element_oid { + // NOTE: We don't validate the element type for non-mixed sequences because + // the outer type like `text[]` would have already ensured we are dealing + // with a Vec + PgTypeInfo::new(TypeId(element_oid), "") + } else { + // mixed sequences can contain values of many different types + // the OID of the type is encoded next to each value + let element_oid = buf.get_u32::()?; + let expected_ty = PgTypeInfo::new(TypeId(element_oid), ""); if !expected_ty.compatible(&T::type_info()) { return Err(crate::Error::mismatched_types::(expected_ty)); } - TypeId(oid) - } else { - // NOTE: We don't validate the element type for non-mixed sequences because - // the outer type like `text[]` would have already ensured we are dealing - // with a Vec - T::type_info().id + expected_ty }; let len = buf.get_i32::()? as isize; let value = if len < 0 { - T::decode(PgValue::null(type_id))? + T::decode(PgValue::null())? } else { let value_buf = &buf[..(len as usize)]; *buf = &buf[(len as usize)..]; - T::decode(PgValue::bytes(type_id, value_buf))? + T::decode(PgValue::bytes(type_info, value_buf))? }; self.len += 1; @@ -146,12 +151,12 @@ impl<'de> PgSequenceDecoder<'de> { // we could use. In TEXT mode, sequences aren't typed. let value = T::decode(if end == Some(0) { - PgValue::null(TypeId(0)) - } else if !self.mixed && value == "NULL" { + PgValue::null() + } else if !self.is_text_record && value == "NULL" { // Yes, in arrays the text encoding of a NULL is just NULL - PgValue::null(TypeId(0)) + PgValue::null() } else { - PgValue::str(TypeId(0), &*value) + PgValue::from_str(&*value) })?; *s = if let Some(end) = end { @@ -168,9 +173,10 @@ impl<'de> PgSequenceDecoder<'de> { } } +#[cfg(test)] impl<'de> From<&'de str> for PgSequenceDecoder<'de> { fn from(s: &'de str) -> Self { - Self::new(PgData::Text(s), false) + Self::new(PgData::Text(s), None) } } diff --git a/sqlx-core/src/postgres/types/record.rs b/sqlx-core/src/postgres/types/record.rs index 09bba756..9d2fcd32 100644 --- a/sqlx-core/src/postgres/types/record.rs +++ b/sqlx-core/src/postgres/types/record.rs @@ -1,7 +1,7 @@ use crate::decode::Decode; use crate::postgres::protocol::TypeId; +use crate::postgres::type_info::PgTypeInfo; use crate::postgres::types::raw::PgRecordDecoder; -use crate::postgres::types::PgTypeInfo; use crate::postgres::value::PgValue; use crate::postgres::Postgres; use crate::types::Type; @@ -11,20 +11,14 @@ macro_rules! impl_pg_record_for_tuple { impl<$($T,)+> Type for ($($T,)+) { #[inline] fn type_info() -> PgTypeInfo { - PgTypeInfo { - id: TypeId(2249), - name: Some("RECORD".into()), - } + PgTypeInfo::new(TypeId::RECORD, "RECORD") } } impl<$($T,)+> Type for [($($T,)+)] { #[inline] fn type_info() -> PgTypeInfo { - PgTypeInfo { - id: TypeId(2287), - name: Some("RECORD[]".into()), - } + PgTypeInfo::new(TypeId::ARRAY_RECORD, "RECORD[]") } } diff --git a/sqlx-core/src/postgres/types/str.rs b/sqlx-core/src/postgres/types/str.rs index 00f9376d..f70e3d12 100644 --- a/sqlx-core/src/postgres/types/str.rs +++ b/sqlx-core/src/postgres/types/str.rs @@ -3,9 +3,7 @@ use std::str::from_utf8; use crate::decode::Decode; use crate::encode::Encode; use crate::postgres::protocol::TypeId; -use crate::postgres::types::PgTypeInfo; -use crate::postgres::value::{PgData, PgValue}; -use crate::postgres::Postgres; +use crate::postgres::{PgData, PgRawBuffer, PgTypeInfo, PgValue, Postgres}; use crate::types::Type; use crate::Error; @@ -46,7 +44,7 @@ impl Type for Vec { } impl Encode for str { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { buf.extend_from_slice(self.as_bytes()); } @@ -56,7 +54,7 @@ impl Encode for str { } impl Encode for String { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { >::encode(self.as_str(), buf) } diff --git a/sqlx-core/src/postgres/types/time.rs b/sqlx-core/src/postgres/types/time.rs index f27019cf..ff6e3606 100644 --- a/sqlx-core/src/postgres/types/time.rs +++ b/sqlx-core/src/postgres/types/time.rs @@ -9,8 +9,7 @@ use crate::decode::Decode; use crate::encode::Encode; use crate::io::Buf; use crate::postgres::protocol::TypeId; -use crate::postgres::types::PgTypeInfo; -use crate::postgres::{PgData, PgValue, Postgres}; +use crate::postgres::{PgData, PgRawBuffer, PgTypeInfo, PgValue, Postgres}; use crate::types::Type; const POSTGRES_EPOCH: PrimitiveDateTime = date!(2000 - 1 - 1).midnight(); @@ -135,7 +134,7 @@ impl<'de> Decode<'de, Postgres> for Time { } impl Encode for Time { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { let micros = microseconds_since_midnight(*self); Encode::::encode(µs, buf); @@ -161,7 +160,7 @@ impl<'de> Decode<'de, Postgres> for Date { } impl Encode for Date { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { let days: i32 = (*self - date!(2000 - 1 - 1)) .whole_days() .try_into() @@ -218,7 +217,7 @@ impl<'de> Decode<'de, Postgres> for PrimitiveDateTime { } impl Encode for PrimitiveDateTime { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { let micros: i64 = (*self - POSTGRES_EPOCH) .whole_microseconds() .try_into() @@ -241,7 +240,7 @@ impl<'de> Decode<'de, Postgres> for OffsetDateTime { } impl Encode for OffsetDateTime { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { let utc_dt = self.to_offset(offset!(UTC)); let primitive_dt = PrimitiveDateTime::new(utc_dt.date(), utc_dt.time()); @@ -258,91 +257,88 @@ use time::time; #[test] fn test_encode_time() { - let mut buf = Vec::new(); + let mut buf = PgRawBuffer::default(); Encode::::encode(&time!(0:00), &mut buf); - assert_eq!(buf, [0; 8]); + assert_eq!(&**buf, [0; 8]); buf.clear(); // one second Encode::::encode(&time!(0:00:01), &mut buf); - assert_eq!(buf, 1_000_000i64.to_be_bytes()); + assert_eq!(&**buf, 1_000_000i64.to_be_bytes()); buf.clear(); // two hours Encode::::encode(&time!(2:00), &mut buf); let expected = 1_000_000i64 * 60 * 60 * 2; - assert_eq!(buf, expected.to_be_bytes()); + assert_eq!(&**buf, expected.to_be_bytes()); buf.clear(); // 3:14:15.000001 Encode::::encode(&time!(3:14:15.000001), &mut buf); let expected = 1_000_000i64 * 60 * 60 * 3 + 1_000_000i64 * 60 * 14 + 1_000_000i64 * 15 + 1; - assert_eq!(buf, expected.to_be_bytes()); + assert_eq!(&**buf, expected.to_be_bytes()); buf.clear(); } #[test] fn test_decode_time() { let buf = [0u8; 8]; - let time: Time = Decode::::decode(PgValue::bytes(TypeId(0), &buf)).unwrap(); + let time: Time = Decode::::decode(PgValue::from_bytes(&buf)).unwrap(); assert_eq!(time, time!(0:00)); // half an hour let buf = (1_000_000i64 * 60 * 30).to_be_bytes(); - let time: Time = Decode::::decode(PgValue::bytes(TypeId(0), &buf)).unwrap(); + let time: Time = Decode::::decode(PgValue::from_bytes(&buf)).unwrap(); assert_eq!(time, time!(0:30)); // 12:53:05.125305 let buf = (1_000_000i64 * 60 * 60 * 12 + 1_000_000i64 * 60 * 53 + 1_000_000i64 * 5 + 125305) .to_be_bytes(); - let time: Time = Decode::::decode(PgValue::bytes(TypeId(0), &buf)).unwrap(); + let time: Time = Decode::::decode(PgValue::from_bytes(&buf)).unwrap(); assert_eq!(time, time!(12:53:05.125305)); } #[test] fn test_encode_datetime() { - let mut buf = Vec::new(); + let mut buf = PgRawBuffer::default(); Encode::::encode(&POSTGRES_EPOCH, &mut buf); - assert_eq!(buf, [0; 8]); + assert_eq!(&**buf, [0; 8]); buf.clear(); // one hour past epoch let date = POSTGRES_EPOCH + 1.hours(); Encode::::encode(&date, &mut buf); - assert_eq!(buf, 3_600_000_000i64.to_be_bytes()); + assert_eq!(&**buf, 3_600_000_000i64.to_be_bytes()); buf.clear(); // some random date let date = PrimitiveDateTime::new(date!(2019 - 12 - 11), time!(11:01:05)); let expected = (date - POSTGRES_EPOCH).whole_microseconds() as i64; Encode::::encode(&date, &mut buf); - assert_eq!(buf, expected.to_be_bytes()); + assert_eq!(&**buf, expected.to_be_bytes()); buf.clear(); } #[test] fn test_decode_datetime() { let buf = [0u8; 8]; - let date: PrimitiveDateTime = - Decode::::decode(PgValue::bytes(TypeId(0), &buf)).unwrap(); + let date: PrimitiveDateTime = Decode::::decode(PgValue::from_bytes(&buf)).unwrap(); assert_eq!( date, PrimitiveDateTime::new(date!(2000 - 01 - 01), time!(00:00:00)) ); let buf = 3_600_000_000i64.to_be_bytes(); - let date: PrimitiveDateTime = - Decode::::decode(PgValue::bytes(TypeId(0), &buf)).unwrap(); + let date: PrimitiveDateTime = Decode::::decode(PgValue::from_bytes(&buf)).unwrap(); assert_eq!( date, PrimitiveDateTime::new(date!(2000 - 01 - 01), time!(01:00:00)) ); let buf = 629_377_265_000_000i64.to_be_bytes(); - let date: PrimitiveDateTime = - Decode::::decode(PgValue::bytes(TypeId(0), &buf)).unwrap(); + let date: PrimitiveDateTime = Decode::::decode(PgValue::from_bytes(&buf)).unwrap(); assert_eq!( date, PrimitiveDateTime::new(date!(2019 - 12 - 11), time!(11:01:05)) @@ -351,16 +347,16 @@ fn test_decode_datetime() { #[test] fn test_encode_offsetdatetime() { - let mut buf = Vec::new(); + let mut buf = PgRawBuffer::default(); Encode::::encode(&POSTGRES_EPOCH.assume_utc(), &mut buf); - assert_eq!(buf, [0; 8]); + assert_eq!(&**buf, [0; 8]); buf.clear(); // one hour past epoch in MSK (2 hours before epoch in UTC) let date = (POSTGRES_EPOCH + 1.hours()).assume_offset(offset!(+3)); Encode::::encode(&date, &mut buf); - assert_eq!(buf, (-7_200_000_000i64).to_be_bytes()); + assert_eq!(&**buf, (-7_200_000_000i64).to_be_bytes()); buf.clear(); // some random date in MSK @@ -368,28 +364,28 @@ fn test_encode_offsetdatetime() { PrimitiveDateTime::new(date!(2019 - 12 - 11), time!(11:01:05)).assume_offset(offset!(+3)); let expected = (date - POSTGRES_EPOCH.assume_utc()).whole_microseconds() as i64; Encode::::encode(&date, &mut buf); - assert_eq!(buf, expected.to_be_bytes()); + assert_eq!(&**buf, expected.to_be_bytes()); buf.clear(); } #[test] fn test_decode_offsetdatetime() { let buf = [0u8; 8]; - let date: OffsetDateTime = Decode::::decode(PgValue::bytes(TypeId(0), &buf)).unwrap(); + let date: OffsetDateTime = Decode::::decode(PgValue::from_bytes(&buf)).unwrap(); assert_eq!( date, PrimitiveDateTime::new(date!(2000 - 01 - 01), time!(00:00:00)).assume_utc() ); let buf = 3_600_000_000i64.to_be_bytes(); - let date: OffsetDateTime = Decode::::decode(PgValue::bytes(TypeId(0), &buf)).unwrap(); + let date: OffsetDateTime = Decode::::decode(PgValue::from_bytes(&buf)).unwrap(); assert_eq!( date, PrimitiveDateTime::new(date!(2000 - 01 - 01), time!(01:00:00)).assume_utc() ); let buf = 629_377_265_000_000i64.to_be_bytes(); - let date: OffsetDateTime = Decode::::decode(PgValue::bytes(TypeId(0), &buf)).unwrap(); + let date: OffsetDateTime = Decode::::decode(PgValue::from_bytes(&buf)).unwrap(); assert_eq!( date, PrimitiveDateTime::new(date!(2019 - 12 - 11), time!(11:01:05)).assume_utc() @@ -398,36 +394,36 @@ fn test_decode_offsetdatetime() { #[test] fn test_encode_date() { - let mut buf = Vec::new(); + let mut buf = PgRawBuffer::default(); let date = date!(2000 - 1 - 1); Encode::::encode(&date, &mut buf); - assert_eq!(buf, [0; 4]); + assert_eq!(&**buf, [0; 4]); buf.clear(); let date = date!(2001 - 1 - 1); Encode::::encode(&date, &mut buf); // 2000 was a leap year - assert_eq!(buf, 366i32.to_be_bytes()); + assert_eq!(&**buf, 366i32.to_be_bytes()); buf.clear(); let date = date!(2019 - 12 - 11); Encode::::encode(&date, &mut buf); - assert_eq!(buf, 7284i32.to_be_bytes()); + assert_eq!(&**buf, 7284i32.to_be_bytes()); buf.clear(); } #[test] fn test_decode_date() { let buf = [0; 4]; - let date: Date = Decode::::decode(PgValue::bytes(TypeId(0), &buf)).unwrap(); + let date: Date = Decode::::decode(PgValue::from_bytes(&buf)).unwrap(); assert_eq!(date, date!(2000 - 01 - 01)); let buf = 366i32.to_be_bytes(); - let date: Date = Decode::::decode(PgValue::bytes(TypeId(0), &buf)).unwrap(); + let date: Date = Decode::::decode(PgValue::from_bytes(&buf)).unwrap(); assert_eq!(date, date!(2001 - 01 - 01)); let buf = 7284i32.to_be_bytes(); - let date: Date = Decode::::decode(PgValue::bytes(TypeId(0), &buf)).unwrap(); + let date: Date = Decode::::decode(PgValue::from_bytes(&buf)).unwrap(); assert_eq!(date, date!(2019 - 12 - 11)); } diff --git a/sqlx-core/src/postgres/types/uuid.rs b/sqlx-core/src/postgres/types/uuid.rs index 3b857e5b..7cc47b15 100644 --- a/sqlx-core/src/postgres/types/uuid.rs +++ b/sqlx-core/src/postgres/types/uuid.rs @@ -5,9 +5,8 @@ use uuid::Uuid; use crate::decode::Decode; use crate::encode::Encode; use crate::postgres::protocol::TypeId; -use crate::postgres::types::PgTypeInfo; use crate::postgres::value::{PgData, PgValue}; -use crate::postgres::Postgres; +use crate::postgres::{PgRawBuffer, PgTypeInfo, Postgres}; use crate::types::Type; impl Type for Uuid { @@ -29,7 +28,7 @@ impl Type for Vec { } impl Encode for Uuid { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { buf.extend_from_slice(self.as_bytes()); } } diff --git a/sqlx-core/src/postgres/value.rs b/sqlx-core/src/postgres/value.rs index e582f686..61d00bd6 100644 --- a/sqlx-core/src/postgres/value.rs +++ b/sqlx-core/src/postgres/value.rs @@ -1,5 +1,4 @@ use crate::error::UnexpectedNullError; -use crate::postgres::protocol::TypeId; use crate::postgres::{PgTypeInfo, Postgres}; use crate::value::RawValue; use std::str::from_utf8; @@ -12,7 +11,7 @@ pub enum PgData<'c> { #[derive(Debug)] pub struct PgValue<'c> { - type_id: TypeId, + type_info: Option, data: Option>, } @@ -33,30 +32,38 @@ impl<'c> PgValue<'c> { self.data } - pub(crate) fn null(type_id: TypeId) -> Self { + pub(crate) fn null() -> Self { Self { - type_id, + type_info: None, data: None, } } - pub(crate) fn bytes(type_id: TypeId, buf: &'c [u8]) -> Self { + pub(crate) fn bytes(type_info: PgTypeInfo, buf: &'c [u8]) -> Self { Self { - type_id, + type_info: Some(type_info), data: Some(PgData::Binary(buf)), } } - pub(crate) fn utf8(type_id: TypeId, buf: &'c [u8]) -> crate::Result { + pub(crate) fn utf8(type_info: PgTypeInfo, buf: &'c [u8]) -> crate::Result { Ok(Self { - type_id, + type_info: Some(type_info), data: Some(PgData::Text(from_utf8(&buf).map_err(crate::Error::decode)?)), }) } - pub(crate) fn str(type_id: TypeId, s: &'c str) -> Self { + #[cfg(test)] + pub(crate) fn from_bytes(buf: &'c [u8]) -> Self { Self { - type_id, + type_info: None, + data: Some(PgData::Binary(buf)), + } + } + + pub(crate) fn from_str(s: &'c str) -> Self { + Self { + type_info: None, data: Some(PgData::Text(s)), } } @@ -66,8 +73,8 @@ impl<'c> RawValue<'c> for PgValue<'c> { type Database = Postgres; fn type_info(&self) -> Option { - if self.data.is_some() { - Some(PgTypeInfo::with_oid(self.type_id.0)) + if let (Some(type_info), Some(_)) = (&self.type_info, &self.data) { + Some(type_info.clone()) } else { None } diff --git a/sqlx-macros/src/derives/attributes.rs b/sqlx-macros/src/derives/attributes.rs index 3c3d1f66..676000bb 100644 --- a/sqlx-macros/src/derives/attributes.rs +++ b/sqlx-macros/src/derives/attributes.rs @@ -33,7 +33,7 @@ pub enum RenameAll { pub struct SqlxContainerAttributes { pub transparent: bool, - pub postgres_oid: Option, + pub rename: Option, pub rename_all: Option, pub repr: Option, } @@ -44,8 +44,8 @@ pub struct SqlxChildAttributes { pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result { let mut transparent = None; - let mut postgres_oid = None; let mut repr = None; + let mut rename = None; let mut rename_all = None; for attr in input { @@ -75,20 +75,11 @@ pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result { - for value in list.nested.iter() { - match value { - NestedMeta::Meta(Meta::NameValue(MetaNameValue { - path, - lit: Lit::Int(val), - .. - })) if path.is_ident("oid") => { - try_set!(postgres_oid, val.base10_parse()?, value); - } - u => fail!(u, "unexpected value"), - } - } - } + Meta::NameValue(MetaNameValue { + path, + lit: Lit::Str(val), + .. + }) if path.is_ident("rename") => try_set!(rename, val.value(), value), u => fail!(u, "unexpected attribute"), }, @@ -113,8 +104,8 @@ pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result syn:: input ); - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_none(), - "unexpected #[sqlx(postgres(oid = ..))]", - input - ); - assert_attribute!( attributes.rename_all.is_none(), "unexpected #[sqlx(rename_all = ..)]", @@ -204,13 +188,6 @@ pub fn check_weak_enum_attributes( ) -> syn::Result { let attributes = check_enum_attributes(input)?; - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_none(), - "unexpected #[sqlx(postgres(oid = ..))]", - input - ); - assert_attribute!(attributes.repr.is_some(), "expected #[repr(..)]", input); assert_attribute!( @@ -238,13 +215,6 @@ pub fn check_strong_enum_attributes( ) -> syn::Result { let attributes = check_enum_attributes(input)?; - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_some(), - "expected #[sqlx(postgres(oid = ..))]", - input - ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); Ok(attributes) @@ -262,13 +232,6 @@ pub fn check_struct_attributes<'a>( input ); - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_some(), - "expected #[sqlx(postgres(oid = ..))]", - input - ); - assert_attribute!( attributes.rename_all.is_none(), "unexpected #[sqlx(rename_all = ..)]", diff --git a/sqlx-macros/src/derives/type.rs b/sqlx-macros/src/derives/type.rs index 07b12cbc..d72a5511 100644 --- a/sqlx-macros/src/derives/type.rs +++ b/sqlx-macros/src/derives/type.rs @@ -117,11 +117,12 @@ fn expand_derive_has_sql_type_strong_enum( } if cfg!(feature = "postgres") { - let oid = attributes.postgres_oid.unwrap(); + let ty_name = attributes.rename.unwrap_or_else(|| ident.to_string()); + tts.extend(quote!( impl sqlx::Type< sqlx::Postgres > for #ident { fn type_info() -> sqlx::postgres::PgTypeInfo { - sqlx::postgres::PgTypeInfo::with_oid(#oid) + sqlx::postgres::PgTypeInfo::with_name(#ty_name) } } )); @@ -150,11 +151,12 @@ fn expand_derive_has_sql_type_struct( let mut tts = proc_macro2::TokenStream::new(); if cfg!(feature = "postgres") { - let oid = attributes.postgres_oid.unwrap(); + let ty_name = attributes.rename.unwrap_or_else(|| ident.to_string()); + tts.extend(quote!( impl sqlx::types::Type< sqlx::Postgres > for #ident { fn type_info() -> sqlx::postgres::PgTypeInfo { - sqlx::postgres::PgTypeInfo::with_oid(#oid) + sqlx::postgres::PgTypeInfo::with_name(#ty_name) } } )); diff --git a/tests/postgres-derives.rs b/tests/postgres-derives.rs index 3cd4a664..d3266c73 100644 --- a/tests/postgres-derives.rs +++ b/tests/postgres-derives.rs @@ -18,7 +18,7 @@ enum Weak { // "Strong" enums can map to TEXT (25) or a custom enum type #[derive(PartialEq, Debug, sqlx::Type)] -#[sqlx(postgres(oid = 25))] +#[sqlx(rename = "text")] #[sqlx(rename_all = "lowercase")] enum Strong { One, @@ -32,7 +32,7 @@ enum Strong { // Records must map to a custom type // Note that all types are types in Postgres // #[derive(PartialEq, Debug, sqlx::Type)] -// #[sqlx(postgres(oid = ?))] +// #[sqlx(rename = "inventory_item")] // struct InventoryItem { // name: String, // supplier_id: Option, diff --git a/tests/postgres-types.rs b/tests/postgres-types.rs index 3977566c..90d844ea 100644 --- a/tests/postgres-types.rs +++ b/tests/postgres-types.rs @@ -1,11 +1,9 @@ extern crate time_ as time; -use std::sync::atomic::{AtomicU32, Ordering}; - use sqlx::decode::Decode; use sqlx::encode::Encode; use sqlx::postgres::types::raw::{PgNumeric, PgNumericSign, PgRecordDecoder, PgRecordEncoder}; -use sqlx::postgres::{PgQueryAs, PgTypeInfo, PgValue}; +use sqlx::postgres::{PgQueryAs, PgRawBuffer, PgTypeInfo, PgValue}; use sqlx::{Cursor, Executor, Postgres, Row, Type}; use sqlx_test::{new, test_prepared_type, test_type}; @@ -440,9 +438,6 @@ async fn test_prepared_structs() -> anyhow::Result<()> { // Setup custom types if needed // - static OID_RECORD_EMPTY: AtomicU32 = AtomicU32::new(0); - static OID_RECORD_1: AtomicU32 = AtomicU32::new(0); - conn.execute( r#" DO $$ BEGIN @@ -455,15 +450,6 @@ END $$; ) .await?; - let type_ids: Vec<(i32,)> = sqlx::query_as( - "SELECT oid::int4 FROM pg_type WHERE typname IN ('_sqlx_record_empty', '_sqlx_record_1')", - ) - .fetch_all(&mut conn) - .await?; - - OID_RECORD_EMPTY.store(type_ids[0].0 as u32, Ordering::SeqCst); - OID_RECORD_1.store(type_ids[1].0 as u32, Ordering::SeqCst); - // // Record of no elements // @@ -472,12 +458,12 @@ END $$; impl Type for RecordEmpty { fn type_info() -> PgTypeInfo { - PgTypeInfo::with_oid(OID_RECORD_EMPTY.load(Ordering::SeqCst)) + PgTypeInfo::with_name("_sqlx_record_empty") } } impl Encode for RecordEmpty { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { PgRecordEncoder::new(buf).finish(); } } @@ -504,12 +490,12 @@ END $$; impl Type for Record1 { fn type_info() -> PgTypeInfo { - PgTypeInfo::with_oid(OID_RECORD_1.load(Ordering::SeqCst)) + PgTypeInfo::with_name("_sqlx_record_1") } } impl Encode for Record1 { - fn encode(&self, buf: &mut Vec) { + fn encode(&self, buf: &mut PgRawBuffer) { PgRecordEncoder::new(buf).encode(self._1).finish(); } }