From c9f3e1adca9637818561bd4a786bde7de1245d3c Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Fri, 12 Jun 2020 15:20:24 -0700 Subject: [PATCH] feat(postgres): add support for built-in range types and allow derives to handle custom range types Co-authored-by: Caio --- sqlx-core/src/encode.rs | 2 + sqlx-core/src/error.rs | 3 +- sqlx-core/src/mysql/type_info.rs | 9 + sqlx-core/src/postgres/arguments.rs | 40 +- sqlx-core/src/postgres/connection/describe.rs | 38 +- sqlx-core/src/postgres/connection/executor.rs | 30 +- sqlx-core/src/postgres/type_info.rs | 50 +- sqlx-core/src/postgres/types/array.rs | 37 +- .../src/postgres/types/{num.rs => int.rs} | 0 sqlx-core/src/postgres/types/mod.rs | 14 +- sqlx-core/src/postgres/types/range.rs | 530 ++++++++++++++++++ sqlx-core/src/postgres/types/ranges.rs | 87 --- .../src/postgres/types/ranges/pg_range.rs | 385 ------------- .../src/postgres/types/ranges/pg_ranges.rs | 84 --- sqlx-core/src/postgres/types/record.rs | 36 +- sqlx-core/src/postgres/value.rs | 38 +- sqlx-macros/src/derives/attributes.rs | 17 +- sqlx-macros/src/derives/decode.rs | 105 +++- sqlx-macros/src/derives/encode.rs | 59 +- sqlx-macros/src/derives/type.rs | 57 +- src/lib.rs | 14 +- tests/postgres/derives.rs | 86 ++- tests/postgres/setup.sql | 6 + tests/postgres/types.rs | 56 +- 24 files changed, 922 insertions(+), 861 deletions(-) rename sqlx-core/src/postgres/types/{num.rs => int.rs} (100%) create mode 100644 sqlx-core/src/postgres/types/range.rs delete mode 100644 sqlx-core/src/postgres/types/ranges.rs delete mode 100644 sqlx-core/src/postgres/types/ranges/pg_range.rs delete mode 100644 sqlx-core/src/postgres/types/ranges/pg_ranges.rs diff --git a/sqlx-core/src/encode.rs b/sqlx-core/src/encode.rs index 7a8f2e2c..c62aa9d1 100644 --- a/sqlx-core/src/encode.rs +++ b/sqlx-core/src/encode.rs @@ -33,6 +33,8 @@ pub trait Encode<'q, DB: Database> { fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> IsNull; fn produces(&self) -> Option { + // `produces` is inherently a hook to allow database drivers to produce value-dependent + // type information; if the driver doesn't need this, it can leave this as `None` None } diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 3a898fb1..eaecf921 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -11,7 +11,8 @@ use crate::database::Database; pub type Result = StdResult; // Convenience type alias for usage within SQLx. -pub type BoxDynError = Box; +// Do not make this type public. +pub(crate) type BoxDynError = Box; /// An unexpected `NULL` was encountered during decoding. /// diff --git a/sqlx-core/src/mysql/type_info.rs b/sqlx-core/src/mysql/type_info.rs index cd5c3ed1..92e87fc5 100644 --- a/sqlx-core/src/mysql/type_info.rs +++ b/sqlx-core/src/mysql/type_info.rs @@ -21,6 +21,15 @@ impl MySqlTypeInfo { } } + #[doc(hidden)] + pub const fn __enum() -> Self { + Self { + r#type: ColumnType::Enum, + flags: ColumnFlags::BINARY, + char_set: 63, + } + } + #[doc(hidden)] pub fn __type_feature_gate(&self) -> Option<&'static str> { match self.r#type { diff --git a/sqlx-core/src/postgres/arguments.rs b/sqlx-core/src/postgres/arguments.rs index d3df15f5..ed379d14 100644 --- a/sqlx-core/src/postgres/arguments.rs +++ b/sqlx-core/src/postgres/arguments.rs @@ -47,26 +47,34 @@ impl<'q> Arguments<'q> for PgArguments { self.types .push(value.produces().unwrap_or_else(T::type_info)); - // reserve space to write the prefixed length of the value - let offset = self.buffer.len(); - self.buffer.extend(&[0; 4]); - // encode the value into our buffer - let len = if let IsNull::No = value.encode(&mut self.buffer) { - (self.buffer.len() - offset - 4) as i32 - } else { - // Write a -1 to indicate NULL - // NOTE: It is illegal for [encode] to write any data - debug_assert_eq!(self.buffer.len(), offset + 4); - -1_i32 - }; - - // write the len to the beginning of the value - self.buffer.buffer[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes()); + self.buffer.encode(value); } } impl PgArgumentBuffer { + pub(crate) fn encode<'q, T>(&mut self, value: T) + where + T: Encode<'q, Postgres>, + { + // reserve space to write the prefixed length of the value + let offset = self.len(); + self.extend(&[0; 4]); + + // encode the value into our buffer + let len = if let IsNull::No = value.encode(self) { + (self.len() - offset - 4) as i32 + } else { + // Write a -1 to indicate NULL + // NOTE: It is illegal for [encode] to write any data + debug_assert_eq!(self.len(), offset + 4); + -1_i32 + }; + + // write the len to the beginning of the value + self[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes()); + } + // Extends the inner buffer by enough space to have an OID // Remembers where the OID goes and type name for the OID pub(crate) fn push_type_hole(&mut self, type_name: &UStr) { @@ -81,7 +89,7 @@ impl PgArgumentBuffer { pub(crate) async fn patch_type_holes(&mut self, conn: &mut PgConnection) -> Result<(), Error> { for (offset, name) in &self.type_holes { let oid = conn.fetch_type_id_by_name(&*name).await?; - self.buffer[*offset..].copy_from_slice(&oid.to_be_bytes()); + self.buffer[*offset..(*offset + 4)].copy_from_slice(&oid.to_be_bytes()); } Ok(()) diff --git a/sqlx-core/src/postgres/connection/describe.rs b/sqlx-core/src/postgres/connection/describe.rs index 8d65101b..b9aa9651 100644 --- a/sqlx-core/src/postgres/connection/describe.rs +++ b/sqlx-core/src/postgres/connection/describe.rs @@ -209,27 +209,31 @@ ORDER BY attnum }) } - async fn fetch_range_by_oid(&mut self, oid: u32, name: String) -> Result { - let _: i32 = query_scalar( - r#" -SELECT 1 + fn fetch_range_by_oid( + &mut self, + oid: u32, + name: String, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let element_oid: u32 = query_scalar( + r#" +SELECT rngsubtype FROM pg_catalog.pg_range WHERE rngtypid = $1 - "#, - ) - .bind(oid) - .fetch_one(self) - .await?; + "#, + ) + .bind(oid) + .fetch_one(&mut *self) + .await?; - let pg_type = PgType::try_from_oid(oid).ok_or_else(|| { - err_protocol!("Trying to retrieve a DB type that doesn't exist in SQLx") - })?; + let element = self.maybe_fetch_type_info_by_oid(element_oid, true).await?; - Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { - kind: PgTypeKind::Range(PgTypeInfo(pg_type)), - name: name.into(), - oid, - })))) + Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { + kind: PgTypeKind::Range(element), + name: name.into(), + oid, + })))) + }) } pub(crate) async fn fetch_type_id_by_name(&mut self, name: &str) -> Result { diff --git a/sqlx-core/src/postgres/connection/executor.rs b/sqlx-core/src/postgres/connection/executor.rs index 2c1b5d85..5a74963b 100644 --- a/sqlx-core/src/postgres/connection/executor.rs +++ b/sqlx-core/src/postgres/connection/executor.rs @@ -28,15 +28,22 @@ async fn prepare( // additional queries here to get any missing OIDs let mut param_types = Vec::with_capacity(arguments.types.len()); + let mut has_fetched = false; for ty in &arguments.types { param_types.push(if let PgType::DeclareWithName(name) = &ty.0 { + has_fetched = true; conn.fetch_type_id_by_name(name).await? } else { ty.0.oid() }); } + // flush and wait until we are re-ready + if has_fetched { + conn.wait_until_ready().await?; + } + // next we send the PARSE command to the server conn.stream.write(Parse { param_types: &*param_types, @@ -111,6 +118,18 @@ impl PgConnection { // patch holes created during encoding arguments.buffer.patch_type_holes(self).await?; + // describe the statement and, again, ask the server to immediately respond + // we need to fully realize the types + self.stream.write(message::Describe::Statement(statement)); + self.stream.write(message::Flush); + self.stream.flush().await?; + + let _ = recv_desc_params(self).await?; + let rows = recv_desc_rows(self).await?; + + self.handle_row_description(rows, true).await?; + self.wait_until_ready().await?; + // bind to attach the arguments to the statement and create a portal self.stream.write(Bind { portal: None, @@ -121,17 +140,6 @@ impl PgConnection { result_formats: &[PgValueFormat::Binary], }); - // describe the portal and, again, ask the server to immediately respond - // we need to fully realize the types - self.stream.write(message::Describe::UnnamedPortal); - self.stream.write(Flush); - self.stream.flush().await?; - - let _ = self.stream.recv_expect(MessageFormat::BindComplete).await?; - - let rows = recv_desc_rows(self).await?; - self.handle_row_description(rows, true).await?; - // executes the portal up to the passed limit // the protocol-level limit acts nearly identically to the `LIMIT` in SQL self.stream.write(message::Execute { diff --git a/sqlx-core/src/postgres/type_info.rs b/sqlx-core/src/postgres/type_info.rs index 7722931b..3047e1db 100644 --- a/sqlx-core/src/postgres/type_info.rs +++ b/sqlx-core/src/postgres/type_info.rs @@ -11,7 +11,7 @@ use crate::type_info::TypeInfo; #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] pub struct PgTypeInfo(pub(crate) PgType); -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone)] #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] #[repr(u32)] pub(crate) enum PgType { @@ -197,7 +197,12 @@ impl PgTypeInfo { Self(PgType::DeclareWithName(UStr::Static(name))) } - pub(crate) const fn with_oid(oid: u32) -> Self { + /// Create a `PgTypeInfo` from an OID. + /// + /// Note that the OID for a type is very dependent on the environment. If you only ever use + /// one database or if this is an unhandled build-in type, you should be fine. Otherwise, + /// you will be better served using [`with_name`](#method.with_name). + pub const fn with_oid(oid: u32) -> Self { Self(PgType::DeclareWithOid(oid)) } } @@ -308,7 +313,14 @@ impl PgType { } pub(crate) fn oid(&self) -> u32 { - match self { + match self.try_oid() { + Some(oid) => oid, + None => unreachable!("(bug) use of unresolved type declaration [oid]"), + } + } + + pub(crate) fn try_oid(&self) -> Option { + Some(match self { PgType::Bool => 16, PgType::Bytea => 17, PgType::Char => 18, @@ -400,8 +412,10 @@ impl PgType { PgType::Custom(ty) => ty.oid, PgType::DeclareWithOid(oid) => *oid, - PgType::DeclareWithName(_) => unreachable!("(bug) use of unresolved type declaration"), - } + PgType::DeclareWithName(_) => { + return None; + } + }) } pub(crate) fn name(&self) -> &str { @@ -576,24 +590,24 @@ impl PgType { PgType::UuidArray => &PgTypeKind::Array(PgTypeInfo(PgType::Uuid)), PgType::Jsonb => &PgTypeKind::Simple, PgType::JsonbArray => &PgTypeKind::Array(PgTypeInfo(PgType::Jsonb)), - PgType::Int4Range => &PgTypeKind::Simple, + PgType::Int4Range => &PgTypeKind::Range(PgTypeInfo::INT4), PgType::Int4RangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::Int4Range)), - PgType::NumRange => &PgTypeKind::Simple, + PgType::NumRange => &PgTypeKind::Range(PgTypeInfo::NUMERIC), PgType::NumRangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::NumRange)), - PgType::TsRange => &PgTypeKind::Simple, + PgType::TsRange => &PgTypeKind::Range(PgTypeInfo::TIMESTAMP), PgType::TsRangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::TsRange)), - PgType::TstzRange => &PgTypeKind::Simple, + PgType::TstzRange => &PgTypeKind::Range(PgTypeInfo::TIMESTAMPTZ), PgType::TstzRangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::TstzRange)), - PgType::DateRange => &PgTypeKind::Simple, + PgType::DateRange => &PgTypeKind::Range(PgTypeInfo::DATE), PgType::DateRangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::DateRange)), - PgType::Int8Range => &PgTypeKind::Simple, + PgType::Int8Range => &PgTypeKind::Range(PgTypeInfo::INT8), PgType::Int8RangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::Int8Range)), PgType::Jsonpath => &PgTypeKind::Simple, PgType::JsonpathArray => &PgTypeKind::Array(PgTypeInfo(PgType::Jsonpath)), PgType::Custom(ty) => &ty.kind, PgType::DeclareWithOid(_) | PgType::DeclareWithName(_) => { - unreachable!("(bug) use of unresolved type declaration") + unreachable!("(bug) use of unresolved type declaration [kind]") } } } @@ -817,3 +831,15 @@ impl Display for PgTypeInfo { f.pad(self.0.name()) } } + +impl PartialEq for PgType { + fn eq(&self, other: &PgType) -> bool { + if let (Some(a), Some(b)) = (self.try_oid(), other.try_oid()) { + // If there are OIDs available, use OIDs to perform a direct match + a == b + } else { + // Otherwise, perform a match on the name + self.name().eq_ignore_ascii_case(other.name()) + } + } +} diff --git a/sqlx-core/src/postgres/types/array.rs b/sqlx-core/src/postgres/types/array.rs index 97674bc0..f86a4b08 100644 --- a/sqlx-core/src/postgres/types/array.rs +++ b/sqlx-core/src/postgres/types/array.rs @@ -63,22 +63,7 @@ where buf.extend(&1_i32.to_be_bytes()); // lower bound for element in self.iter() { - // allocate space for the length of the encoded element - let el_len_offset = buf.len(); - buf.extend(&0_i32.to_be_bytes()); - - let el_start = buf.len(); - - if let IsNull::Yes = element.encode_by_ref(buf) { - // NULL is encoded as -1 for a length - buf[el_len_offset..el_start].copy_from_slice(&(-1_i32).to_be_bytes()); - } else { - let el_end = buf.len(); - let el_len = el_end - el_start; - - // now we can go back and update the length - buf[el_len_offset..el_start].copy_from_slice(&(el_len as i32).to_be_bytes()); - } + buf.encode(element); } IsNull::No @@ -144,23 +129,11 @@ where let mut elements = Vec::with_capacity(len as usize); for _ in 0..len { - let mut element_len = buf.get_i32(); - - let element_val = if element_len == -1 { - element_len = 0; - None - } else { - Some(&buf[..(element_len as usize)]) - }; - - elements.push(T::decode(PgValueRef { - value: element_val, - row: None, - type_info: element_type_info.clone(), + elements.push(T::decode(PgValueRef::get( + &mut buf, format, - })?); - - buf.advance(element_len as usize); + element_type_info.clone(), + ))?) } Ok(elements) diff --git a/sqlx-core/src/postgres/types/num.rs b/sqlx-core/src/postgres/types/int.rs similarity index 100% rename from sqlx-core/src/postgres/types/num.rs rename to sqlx-core/src/postgres/types/int.rs diff --git a/sqlx-core/src/postgres/types/mod.rs b/sqlx-core/src/postgres/types/mod.rs index 0fb95916..ce355c54 100644 --- a/sqlx-core/src/postgres/types/mod.rs +++ b/sqlx-core/src/postgres/types/mod.rs @@ -132,8 +132,8 @@ mod array; mod bool; mod bytes; mod float; -mod num; -mod ranges; +mod int; +mod range; mod record; mod str; mod tuple; @@ -159,7 +159,9 @@ mod json; #[cfg(feature = "ipnetwork")] mod ipnetwork; -pub use { - ranges::{pg_range::PgRange, pg_ranges::*}, - record::{PgRecordDecoder, PgRecordEncoder}, -}; +pub use range::PgRange; + +// used in derive(Type) for `struct` +// but the interface is not considered part of the public API +#[doc(hidden)] +pub use record::{PgRecordDecoder, PgRecordEncoder}; diff --git a/sqlx-core/src/postgres/types/range.rs b/sqlx-core/src/postgres/types/range.rs new file mode 100644 index 00000000..56138982 --- /dev/null +++ b/sqlx-core/src/postgres/types/range.rs @@ -0,0 +1,530 @@ +use std::fmt::{self, Debug, Display, Formatter}; +use std::ops::{Bound, Range, RangeBounds, RangeFrom, RangeInclusive, RangeTo, RangeToInclusive}; + +use bitflags::bitflags; +use bytes::Buf; + +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::postgres::{ + PgArgumentBuffer, PgTypeInfo, PgTypeKind, PgValueFormat, PgValueRef, Postgres, +}; +use crate::types::Type; + +// https://github.com/postgres/postgres/blob/2f48ede080f42b97b594fb14102c82ca1001b80c/src/include/utils/rangetypes.h#L35-L44 +bitflags! { + struct RangeFlags: u8 { + const EMPTY = 0x01; + const LB_INC = 0x02; + const UB_INC = 0x04; + const LB_INF = 0x08; + const UB_INF = 0x10; + const LB_NULL = 0x20; // not used + const UB_NULL = 0x40; // not used + const CONTAIN_EMPTY = 0x80; // internal + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct PgRange { + pub start: Bound, + pub end: Bound, +} + +impl From<[Bound; 2]> for PgRange { + fn from(v: [Bound; 2]) -> Self { + let [start, end] = v; + Self { start, end } + } +} + +impl From<(Bound, Bound)> for PgRange { + fn from(v: (Bound, Bound)) -> Self { + Self { + start: v.0, + end: v.1, + } + } +} + +impl From> for PgRange { + fn from(v: Range) -> Self { + Self { + start: Bound::Included(v.start), + end: Bound::Excluded(v.end), + } + } +} + +impl From> for PgRange { + fn from(v: RangeFrom) -> Self { + Self { + start: Bound::Included(v.start), + end: Bound::Unbounded, + } + } +} + +impl From> for PgRange { + fn from(v: RangeInclusive) -> Self { + let (start, end) = v.into_inner(); + Self { + start: Bound::Included(start), + end: Bound::Included(end), + } + } +} + +impl From> for PgRange { + fn from(v: RangeTo) -> Self { + Self { + start: Bound::Unbounded, + end: Bound::Excluded(v.end), + } + } +} + +impl From> for PgRange { + fn from(v: RangeToInclusive) -> Self { + Self { + start: Bound::Unbounded, + end: Bound::Included(v.end), + } + } +} + +impl RangeBounds for PgRange { + fn start_bound(&self) -> Bound<&T> { + match self.start { + Bound::Included(ref start) => Bound::Included(start), + Bound::Excluded(ref start) => Bound::Excluded(start), + Bound::Unbounded => Bound::Unbounded, + } + } + + fn end_bound(&self) -> Bound<&T> { + match self.end { + Bound::Included(ref end) => Bound::Included(end), + Bound::Excluded(ref end) => Bound::Excluded(end), + Bound::Unbounded => Bound::Unbounded, + } + } +} + +impl Type for PgRange { + fn type_info() -> PgTypeInfo { + PgTypeInfo::INT4_RANGE + } +} + +impl Type for PgRange { + fn type_info() -> PgTypeInfo { + PgTypeInfo::INT8_RANGE + } +} + +#[cfg(feature = "bigdecimal")] +impl Type for PgRange { + fn type_info() -> PgTypeInfo { + PgTypeInfo::NUM_RANGE + } +} + +#[cfg(feature = "chrono")] +impl Type for PgRange { + fn type_info() -> PgTypeInfo { + PgTypeInfo::DATE_RANGE + } +} + +#[cfg(feature = "chrono")] +impl Type for PgRange { + fn type_info() -> PgTypeInfo { + PgTypeInfo::TS_RANGE + } +} + +#[cfg(feature = "chrono")] +impl Type for PgRange> { + fn type_info() -> PgTypeInfo { + PgTypeInfo::TSTZ_RANGE + } +} + +#[cfg(feature = "time")] +impl Type for PgRange { + fn type_info() -> PgTypeInfo { + PgTypeInfo::DATE_RANGE + } +} + +#[cfg(feature = "time")] +impl Type for PgRange { + fn type_info() -> PgTypeInfo { + PgTypeInfo::TS_RANGE + } +} + +#[cfg(feature = "time")] +impl Type for PgRange { + fn type_info() -> PgTypeInfo { + PgTypeInfo::TSTZ_RANGE + } +} + +impl Type for [PgRange] { + fn type_info() -> PgTypeInfo { + PgTypeInfo::INT4_RANGE_ARRAY + } +} + +impl Type for [PgRange] { + fn type_info() -> PgTypeInfo { + PgTypeInfo::INT8_RANGE_ARRAY + } +} + +#[cfg(feature = "bigdecimal")] +impl Type for [PgRange] { + fn type_info() -> PgTypeInfo { + PgTypeInfo::NUM_RANGE_ARRAY + } +} + +#[cfg(feature = "chrono")] +impl Type for [PgRange] { + fn type_info() -> PgTypeInfo { + PgTypeInfo::DATE_RANGE_ARRAY + } +} + +#[cfg(feature = "chrono")] +impl Type for [PgRange] { + fn type_info() -> PgTypeInfo { + PgTypeInfo::TS_RANGE_ARRAY + } +} + +#[cfg(feature = "chrono")] +impl Type for [PgRange>] { + fn type_info() -> PgTypeInfo { + PgTypeInfo::TSTZ_RANGE_ARRAY + } +} + +#[cfg(feature = "time")] +impl Type for [PgRange] { + fn type_info() -> PgTypeInfo { + PgTypeInfo::DATE_RANGE_ARRAY + } +} + +#[cfg(feature = "time")] +impl Type for [PgRange] { + fn type_info() -> PgTypeInfo { + PgTypeInfo::TS_RANGE_ARRAY + } +} + +#[cfg(feature = "time")] +impl Type for [PgRange] { + fn type_info() -> PgTypeInfo { + PgTypeInfo::TSTZ_RANGE_ARRAY + } +} + +impl Type for Vec> { + fn type_info() -> PgTypeInfo { + PgTypeInfo::INT4_RANGE_ARRAY + } +} + +impl Type for Vec> { + fn type_info() -> PgTypeInfo { + PgTypeInfo::INT8_RANGE_ARRAY + } +} + +#[cfg(feature = "bigdecimal")] +impl Type for Vec> { + fn type_info() -> PgTypeInfo { + PgTypeInfo::NUM_RANGE_ARRAY + } +} + +#[cfg(feature = "chrono")] +impl Type for Vec> { + fn type_info() -> PgTypeInfo { + PgTypeInfo::DATE_RANGE_ARRAY + } +} + +#[cfg(feature = "chrono")] +impl Type for Vec> { + fn type_info() -> PgTypeInfo { + PgTypeInfo::TS_RANGE_ARRAY + } +} + +#[cfg(feature = "chrono")] +impl Type for Vec>> { + fn type_info() -> PgTypeInfo { + PgTypeInfo::TSTZ_RANGE_ARRAY + } +} + +#[cfg(feature = "time")] +impl Type for Vec> { + fn type_info() -> PgTypeInfo { + PgTypeInfo::DATE_RANGE_ARRAY + } +} + +#[cfg(feature = "time")] +impl Type for Vec> { + fn type_info() -> PgTypeInfo { + PgTypeInfo::TS_RANGE_ARRAY + } +} + +#[cfg(feature = "time")] +impl Type for Vec> { + fn type_info() -> PgTypeInfo { + PgTypeInfo::TSTZ_RANGE_ARRAY + } +} + +impl<'q, T> Encode<'q, Postgres> for PgRange +where + T: Encode<'q, Postgres>, +{ + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + // https://github.com/postgres/postgres/blob/2f48ede080f42b97b594fb14102c82ca1001b80c/src/backend/utils/adt/rangetypes.c#L245 + + let mut flags = RangeFlags::empty(); + + flags |= match self.start { + Bound::Included(_) => RangeFlags::LB_INC, + Bound::Unbounded => RangeFlags::LB_INF, + Bound::Excluded(_) => RangeFlags::empty(), + }; + + flags |= match self.end { + Bound::Included(_) => RangeFlags::UB_INC, + Bound::Unbounded => RangeFlags::UB_INF, + Bound::Excluded(_) => RangeFlags::empty(), + }; + + buf.push(flags.bits()); + + if let Bound::Included(v) | Bound::Excluded(v) = &self.start { + buf.encode(v); + } + + if let Bound::Included(v) | Bound::Excluded(v) = &self.end { + buf.encode(v); + } + + // ranges are themselves never null + IsNull::No + } +} + +impl<'r, T> Decode<'r, Postgres> for PgRange +where + T: Type + for<'a> Decode<'a, Postgres>, +{ + fn accepts(ty: &PgTypeInfo) -> bool { + // we require the declared type to be a _range_ with an + // element type that is acceptable + if let PgTypeKind::Range(element) = &ty.0.kind() { + return T::accepts(&element); + } + + false + } + + fn decode(value: PgValueRef<'r>) -> Result { + match value.format { + PgValueFormat::Binary => { + let element_ty = if let PgTypeKind::Range(element) = &value.type_info.0.kind() { + element + } else { + return Err(format!("unexpected non-range type {}", value.type_info).into()); + }; + + let mut buf = value.as_bytes()?; + + let mut start = Bound::Unbounded; + let mut end = Bound::Unbounded; + + let flags = RangeFlags::from_bits_truncate(buf.get_u8()); + + if flags.contains(RangeFlags::EMPTY) { + return Ok(PgRange { start, end }); + } + + if !flags.contains(RangeFlags::LB_INF) { + let value = + T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone()))?; + + start = if flags.contains(RangeFlags::LB_INC) { + Bound::Included(value) + } else { + Bound::Excluded(value) + }; + } + + if !flags.contains(RangeFlags::UB_INF) { + let value = + T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone()))?; + + end = if flags.contains(RangeFlags::UB_INC) { + Bound::Included(value) + } else { + Bound::Excluded(value) + }; + } + + Ok(PgRange { start, end }) + } + + PgValueFormat::Text => { + // https://github.com/postgres/postgres/blob/2f48ede080f42b97b594fb14102c82ca1001b80c/src/backend/utils/adt/rangetypes.c#L2046 + + let mut start = None; + let mut end = None; + + let s = value.as_str()?; + + // remember the bounds + let sb = s.as_bytes(); + let lower = sb[0] as char; + let upper = sb[sb.len() - 1] as char; + + // trim the wrapping braces/brackets + let s = &s[1..(s.len() - 1)]; + + let mut chars = s.chars(); + + let mut element = String::new(); + let mut done = false; + let mut quoted = false; + let mut in_quotes = false; + let mut in_escape = false; + let mut prev_ch = '\0'; + let mut count = 0; + + while !done { + element.clear(); + + loop { + match chars.next() { + Some(ch) => { + match ch { + _ if in_escape => { + element.push(ch); + in_escape = false; + } + + '"' if in_quotes => { + in_quotes = false; + } + + '"' => { + in_quotes = true; + quoted = true; + + if prev_ch == '"' { + element.push('"') + } + } + + '\\' if !in_escape => { + in_escape = true; + } + + ',' if !in_quotes => break, + + _ => { + element.push(ch); + } + } + prev_ch = ch; + } + + None => { + done = true; + break; + } + } + } + + count += 1; + if !(element.is_empty() && !quoted) { + let value = Some(T::decode(PgValueRef { + type_info: T::type_info(), + format: PgValueFormat::Text, + value: Some(element.as_bytes()), + row: None, + })?); + + if count == 1 { + start = value; + } else if count == 2 { + end = value; + } else { + return Err("more than 2 elements found in a range".into()); + } + } + } + + let start = parse_bound(lower, start)?; + let end = parse_bound(upper, end)?; + + Ok(PgRange { start, end }) + } + } + } +} + +fn parse_bound(ch: char, value: Option) -> Result, BoxDynError> { + Ok(if let Some(value) = value { + match ch { + '(' | ')' => Bound::Excluded(value), + '[' | ']' => Bound::Included(value), + + _ => { + return Err(format!( + "expected `(`, ')', '[', or `]` but found `{}` for range literal", + ch + ) + .into()); + } + } + } else { + Bound::Unbounded + }) +} + +impl Display for PgRange +where + T: Display, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match &self.start { + Bound::Unbounded => f.write_str("(,")?, + Bound::Excluded(v) => write!(f, "({},", v)?, + Bound::Included(v) => write!(f, "[{},", v)?, + } + + match &self.end { + Bound::Unbounded => f.write_str(")")?, + Bound::Excluded(v) => write!(f, "{})", v)?, + Bound::Included(v) => write!(f, "{}]", v)?, + } + + Ok(()) + } +} diff --git a/sqlx-core/src/postgres/types/ranges.rs b/sqlx-core/src/postgres/types/ranges.rs deleted file mode 100644 index 92eb151d..00000000 --- a/sqlx-core/src/postgres/types/ranges.rs +++ /dev/null @@ -1,87 +0,0 @@ -pub(crate) mod pg_range; -pub(crate) mod pg_ranges; - -use crate::{ - decode::Decode, - encode::{Encode, IsNull}, - postgres::{types::PgRange, PgArgumentBuffer, PgTypeInfo, PgValueRef, Postgres}, - types::Type, -}; -use core::{ - convert::TryInto, - ops::{Range, RangeFrom, RangeInclusive, RangeTo, RangeToInclusive}, -}; - -macro_rules! impl_range { - ($range:ident) => { - impl<'a, T> Decode<'a, Postgres> for $range - where - T: for<'b> Decode<'b, Postgres> + Type + 'a, - { - fn accepts(ty: &PgTypeInfo) -> bool { - as Decode<'_, Postgres>>::accepts(ty) - } - - fn decode(value: PgValueRef<'a>) -> Result<$range, crate::error::BoxDynError> { - let bounds: PgRange = Decode::::decode(value)?; - let rslt = bounds.try_into()?; - Ok(rslt) - } - } - - impl<'a, T> Encode<'a, Postgres> for $range - where - T: Clone + for<'b> Encode<'b, Postgres> + 'a, - { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { - as Encode<'_, Postgres>>::encode(self.clone().into(), buf) - } - } - }; -} - -impl_range!(Range); -impl_range!(RangeFrom); -impl_range!(RangeInclusive); -impl_range!(RangeTo); -impl_range!(RangeToInclusive); - -#[test] -fn test_decode_str_bounds() { - use crate::postgres::type_info::PgType; - - const EXC1: Bound = Bound::Excluded(1); - const EXC2: Bound = Bound::Excluded(2); - const INC1: Bound = Bound::Included(1); - const INC2: Bound = Bound::Included(2); - const UNB: Bound = Bound::Unbounded; - - let check = |s: &str, range_cmp: [Bound; 2]| { - let pg_value = PgValueRef { - type_info: PgTypeInfo(PgType::Int4Range), - format: PgValueFormat::Text, - value: Some(s.as_bytes()), - row: None, - }; - let range: PgRange = Decode::::decode(pg_value).unwrap(); - assert_eq!(Into::<[Bound; 2]>::into(range), range_cmp); - }; - - check("(,)", [UNB, UNB]); - check("(,]", [UNB, UNB]); - check("(,2)", [UNB, EXC2]); - check("(,2]", [UNB, INC2]); - check("(1,)", [EXC1, UNB]); - check("(1,]", [EXC1, UNB]); - check("(1,2)", [EXC1, EXC2]); - check("(1,2]", [EXC1, INC2]); - - check("[,)", [UNB, UNB]); - check("[,]", [UNB, UNB]); - check("[,2)", [UNB, EXC2]); - check("[,2]", [UNB, INC2]); - check("[1,)", [INC1, UNB]); - check("[1,]", [INC1, UNB]); - check("[1,2)", [INC1, EXC2]); - check("[1,2]", [INC1, INC2]); -} diff --git a/sqlx-core/src/postgres/types/ranges/pg_range.rs b/sqlx-core/src/postgres/types/ranges/pg_range.rs deleted file mode 100644 index fb5402e8..00000000 --- a/sqlx-core/src/postgres/types/ranges/pg_range.rs +++ /dev/null @@ -1,385 +0,0 @@ -use crate::{ - decode::Decode, - encode::{Encode, IsNull}, - postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}, - types::Type, -}; -use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; -use core::{ - convert::TryFrom, - ops::{Bound, Range, RangeBounds, RangeFrom, RangeInclusive, RangeTo, RangeToInclusive}, -}; - -bitflags::bitflags! { - struct RangeFlags: u8 { - const EMPTY = 0x01; - const LB_INC = 0x02; - const UB_INC = 0x04; - const LB_INF = 0x08; - const UB_INF = 0x10; - const LB_NULL = 0x20; - const UB_NULL = 0x40; - const CONTAIN_EMPTY = 0x80; - } -} - -#[derive(Clone, Debug, Hash, PartialEq, Eq)] -pub struct PgRange { - pub start: Bound, - pub end: Bound, -} - -impl PgRange { - pub fn new(start: Bound, end: Bound) -> Self { - Self { start, end } - } -} - -impl<'a, T> Decode<'a, Postgres> for PgRange -where - T: for<'b> Decode<'b, Postgres> + Type + 'a, -{ - fn accepts(ty: &PgTypeInfo) -> bool { - [ - PgTypeInfo::INT4_RANGE, - PgTypeInfo::NUM_RANGE, - PgTypeInfo::TS_RANGE, - PgTypeInfo::TSTZ_RANGE, - PgTypeInfo::DATE_RANGE, - PgTypeInfo::INT8_RANGE, - ] - .contains(ty) - } - - fn decode(value: PgValueRef<'a>) -> Result, crate::error::BoxDynError> { - match value.format() { - PgValueFormat::Binary => { - decode_binary(value.as_bytes()?, value.format, value.type_info) - } - PgValueFormat::Text => decode_str(value.as_str()?, value.format(), value.type_info), - } - } -} - -impl<'a, T> Encode<'a, Postgres> for PgRange -where - T: for<'b> Encode<'b, Postgres> + 'a, -{ - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { - let mut flags = match self.start { - Bound::Included(_) => RangeFlags::LB_INC, - Bound::Excluded(_) => RangeFlags::empty(), - Bound::Unbounded => RangeFlags::LB_INF, - }; - - flags |= match self.end { - Bound::Included(_) => RangeFlags::UB_INC, - Bound::Excluded(_) => RangeFlags::empty(), - Bound::Unbounded => RangeFlags::UB_INF, - }; - - buf.write_u8(flags.bits()).unwrap(); - - let mut write = |bound: &Bound| -> IsNull { - match bound { - Bound::Included(ref value) | Bound::Excluded(ref value) => { - buf.write_u32::(0).unwrap(); - let prev = buf.len(); - if let IsNull::Yes = Encode::::encode(value, buf) { - return IsNull::Yes; - } - let len = buf.len() - prev; - buf[prev - 4..prev].copy_from_slice(&(len as u32).to_be_bytes()); - } - Bound::Unbounded => {} - } - IsNull::No - }; - - if let IsNull::Yes = write(&self.start) { - return IsNull::Yes; - } - write(&self.end) - } -} - -impl From<[Bound; 2]> for PgRange { - fn from(from: [Bound; 2]) -> Self { - let [start, end] = from; - Self { start, end } - } -} - -impl From<(Bound, Bound)> for PgRange { - fn from(from: (Bound, Bound)) -> Self { - Self { - start: from.0, - end: from.1, - } - } -} - -impl From> for [Bound; 2] { - fn from(from: PgRange) -> Self { - [from.start, from.end] - } -} - -impl From> for (Bound, Bound) { - fn from(from: PgRange) -> Self { - (from.start, from.end) - } -} - -impl From> for PgRange { - fn from(from: Range) -> Self { - Self { - start: Bound::Included(from.start), - end: Bound::Excluded(from.end), - } - } -} - -impl From> for PgRange { - fn from(from: RangeFrom) -> Self { - Self { - start: Bound::Included(from.start), - end: Bound::Unbounded, - } - } -} - -impl From> for PgRange { - fn from(from: RangeInclusive) -> Self { - let (start, end) = from.into_inner(); - Self { - start: Bound::Included(start), - end: Bound::Excluded(end), - } - } -} - -impl From> for PgRange { - fn from(from: RangeTo) -> Self { - Self { - start: Bound::Unbounded, - end: Bound::Excluded(from.end), - } - } -} - -impl From> for PgRange { - fn from(from: RangeToInclusive) -> Self { - Self { - start: Bound::Unbounded, - end: Bound::Included(from.end), - } - } -} - -impl RangeBounds for PgRange { - fn start_bound(&self) -> Bound<&T> { - match &self.start { - Bound::Included(ref start) => Bound::Included(start), - Bound::Excluded(ref start) => Bound::Excluded(start), - Bound::Unbounded => Bound::Unbounded, - } - } - - fn end_bound(&self) -> Bound<&T> { - match &self.end { - Bound::Included(ref end) => Bound::Included(end), - Bound::Excluded(ref end) => Bound::Excluded(end), - Bound::Unbounded => Bound::Unbounded, - } - } -} - -impl TryFrom> for Range { - type Error = crate::error::Error; - - fn try_from(from: PgRange) -> crate::error::Result { - let err_msg = "Invalid data for core::ops::Range"; - let start = included(from.start, err_msg)?; - let end = excluded(from.end, err_msg)?; - Ok(start..end) - } -} - -impl TryFrom> for RangeFrom { - type Error = crate::error::Error; - - fn try_from(from: PgRange) -> crate::error::Result { - let err_msg = "Invalid data for core::ops::RangeFrom"; - let start = included(from.start, err_msg)?; - unbounded(from.end, err_msg)?; - Ok(start..) - } -} - -impl TryFrom> for RangeInclusive { - type Error = crate::error::Error; - - fn try_from(from: PgRange) -> crate::error::Result { - let err_msg = "Invalid data for core::ops::RangeInclusive"; - let start = included(from.start, err_msg)?; - let end = included(from.end, err_msg)?; - Ok(start..=end) - } -} - -impl TryFrom> for RangeTo { - type Error = crate::error::Error; - - fn try_from(from: PgRange) -> crate::error::Result { - let err_msg = "Invalid data for core::ops::RangeTo"; - unbounded(from.start, err_msg)?; - let end = excluded(from.end, err_msg)?; - Ok(..end) - } -} - -impl TryFrom> for RangeToInclusive { - type Error = crate::error::Error; - - fn try_from(from: PgRange) -> crate::error::Result { - let err_msg = "Invalid data for core::ops::RangeToInclusive"; - unbounded(from.start, err_msg)?; - let end = included(from.end, err_msg)?; - Ok(..=end) - } -} - -fn decode_binary<'r, T>( - mut bytes: &[u8], - format: PgValueFormat, - type_info: PgTypeInfo, -) -> Result, crate::error::BoxDynError> -where - T: for<'rec> Decode<'rec, Postgres> + 'r, -{ - let flags: RangeFlags = RangeFlags::from_bits_truncate(bytes.read_u8()?); - let mut start_value = Bound::Unbounded; - let mut end_value = Bound::Unbounded; - - if flags.contains(RangeFlags::EMPTY) { - return Ok(PgRange { - start: start_value, - end: end_value, - }); - } - - if !flags.contains(RangeFlags::LB_INF) { - let elem_size = bytes.read_i32::()?; - let (elem_bytes, new_bytes) = bytes.split_at(elem_size as usize); - bytes = new_bytes; - let value = T::decode(PgValueRef { - type_info: type_info.clone(), - format, - value: Some(elem_bytes), - row: None, - })?; - - start_value = if flags.contains(RangeFlags::LB_INC) { - Bound::Included(value) - } else { - Bound::Excluded(value) - }; - } - - if !flags.contains(RangeFlags::UB_INF) { - bytes.read_i32::()?; - let value = T::decode(PgValueRef { - type_info, - format, - value: Some(bytes), - row: None, - })?; - - end_value = if flags.contains(RangeFlags::UB_INC) { - Bound::Included(value) - } else { - Bound::Excluded(value) - }; - } - - Ok(PgRange { - start: start_value, - end: end_value, - }) -} - -fn decode_str<'r, T>( - s: &str, - format: PgValueFormat, - type_info: PgTypeInfo, -) -> Result, crate::error::BoxDynError> -where - T: for<'rec> Decode<'rec, Postgres> + 'r, -{ - let err = || crate::error::Error::Decode("Invalid PostgreSQL range string".into()); - - let value = - |bound: &str, delim, bounds: [&str; 2]| -> Result, crate::error::BoxDynError> { - if bound.len() == 0 { - return Ok(Bound::Unbounded); - } - let bound_value = T::decode(PgValueRef { - type_info: type_info.clone(), - format, - value: Some(bound.as_bytes()), - row: None, - })?; - if delim == bounds[0] { - Ok(Bound::Excluded(bound_value)) - } else if delim == bounds[1] { - Ok(Bound::Included(bound_value)) - } else { - Err(Box::new(err())) - } - }; - - let mut parts = s.split(','); - let start_str = parts.next().ok_or_else(err)?; - let start_value = value( - start_str.get(1..).ok_or_else(err)?, - start_str.get(0..1).ok_or_else(err)?, - ["(", "["], - )?; - let end_str = parts.next().ok_or_else(err)?; - let last_char_idx = end_str.len() - 1; - let end_value = value( - end_str.get(..last_char_idx).ok_or_else(err)?, - end_str.get(last_char_idx..).ok_or_else(err)?, - [")", "]"], - )?; - - Ok(PgRange { - start: start_value, - end: end_value, - }) -} - -fn excluded(b: Bound, err_msg: &str) -> crate::error::Result { - if let Bound::Excluded(rslt) = b { - Ok(rslt) - } else { - Err(crate::error::Error::Decode(err_msg.into())) - } -} - -fn included(b: Bound, err_msg: &str) -> crate::error::Result { - if let Bound::Included(rslt) = b { - Ok(rslt) - } else { - Err(crate::error::Error::Decode(err_msg.into())) - } -} - -fn unbounded(b: Bound, err_msg: &str) -> crate::error::Result<()> { - if matches!(b, Bound::Unbounded) { - Ok(()) - } else { - Err(crate::error::Error::Decode(err_msg.into())) - } -} diff --git a/sqlx-core/src/postgres/types/ranges/pg_ranges.rs b/sqlx-core/src/postgres/types/ranges/pg_ranges.rs deleted file mode 100644 index 0f436cd7..00000000 --- a/sqlx-core/src/postgres/types/ranges/pg_ranges.rs +++ /dev/null @@ -1,84 +0,0 @@ -use crate::{ - decode::Decode, - encode::{Encode, IsNull}, - postgres::{ - types::ranges::pg_range::PgRange, PgArgumentBuffer, PgTypeInfo, PgValueRef, Postgres, - }, - types::Type, -}; - -macro_rules! impl_pg_range { - ($range_name:ident, $type_info:expr, $type_info_array:expr, $range_type:ty) => { - #[derive(Clone, Debug, Hash, PartialEq, Eq)] - #[repr(transparent)] - pub struct $range_name(pub PgRange<$range_type>); - - impl<'a> Decode<'a, Postgres> for $range_name { - fn accepts(ty: &PgTypeInfo) -> bool { - as Decode<'_, Postgres>>::accepts(ty) - } - - fn decode(value: PgValueRef<'a>) -> Result<$range_name, crate::error::BoxDynError> { - Ok(Self(Decode::::decode(value)?)) - } - } - - impl<'a> Encode<'a, Postgres> for $range_name { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { - as Encode<'_, Postgres>>::encode_by_ref(&self.0, buf) - } - } - - impl Type for $range_name { - fn type_info() -> PgTypeInfo { - $type_info - } - } - - impl Type for [$range_name] { - fn type_info() -> PgTypeInfo { - $type_info_array - } - } - - impl Type for Vec<$range_name> { - fn type_info() -> PgTypeInfo { - $type_info_array - } - } - }; -} - -impl_pg_range!( - Int4Range, - PgTypeInfo::INT4_RANGE, - PgTypeInfo::INT4_RANGE_ARRAY, - i32 -); -#[cfg(feature = "bigdecimal")] -impl_pg_range!( - NumRange, - PgTypeInfo::NUM_RANGE, - PgTypeInfo::NUM_RANGE_ARRAY, - bigdecimal::BigDecimal -); -#[cfg(feature = "chrono")] -impl_pg_range!( - TsRange, - PgTypeInfo::TS_RANGE, - PgTypeInfo::TS_RANGE_ARRAY, - chrono::NaiveDateTime -); -#[cfg(feature = "chrono")] -impl_pg_range!( - DateRange, - PgTypeInfo::DATE_RANGE, - PgTypeInfo::DATE_RANGE_ARRAY, - chrono::NaiveDate -); -impl_pg_range!( - Int8Range, - PgTypeInfo::INT8_RANGE, - PgTypeInfo::INT8_RANGE_ARRAY, - i64 -); diff --git a/sqlx-core/src/postgres/types/record.rs b/sqlx-core/src/postgres/types/record.rs index 4ed924ee..c059c09e 100644 --- a/sqlx-core/src/postgres/types/record.rs +++ b/sqlx-core/src/postgres/types/record.rs @@ -1,7 +1,7 @@ use bytes::Buf; use crate::decode::Decode; -use crate::encode::{Encode, IsNull}; +use crate::encode::Encode; use crate::error::{mismatched_types, BoxDynError}; use crate::postgres::type_info::PgType; use crate::postgres::{ @@ -30,7 +30,7 @@ impl<'a> PgRecordEncoder<'a> { #[doc(hidden)] pub fn finish(&mut self) { // fill in the record length - self.buf[self.off..].copy_from_slice(&self.num.to_be_bytes()); + self.buf[self.off..(self.off + 4)].copy_from_slice(&self.num.to_be_bytes()); } #[doc(hidden)] @@ -50,16 +50,8 @@ impl<'a> PgRecordEncoder<'a> { self.buf.extend(&ty.0.oid().to_be_bytes()); } - let offset = self.buf.len(); - self.buf.extend(&(0_u32).to_be_bytes()); - - let size = if let IsNull::Yes = value.encode(self.buf) { - -1 - } else { - (self.buf.len() - offset + 4) as i32 - }; - - self.buf[offset..].copy_from_slice(&size.to_be_bytes()); + self.buf.encode(value); + self.num += 1; self } @@ -133,6 +125,8 @@ impl<'r> PgRecordDecoder<'r> { } }; + self.ind += 1; + if let Some(ty) = &element_type_opt { if !T::accepts(ty) { return Err(mismatched_types::(&T::type_info(), ty)); @@ -142,23 +136,7 @@ impl<'r> PgRecordDecoder<'r> { let element_type = element_type_opt.unwrap_or_else(|| PgTypeInfo::with_oid(element_type_oid)); - let mut element_len = self.buf.get_i32(); - let element_buf = if element_len < 0 { - element_len = 0; - None - } else { - Some(&self.buf[..(element_len as usize)]) - }; - - self.buf.advance(element_len as usize); - self.ind += 1; - - T::decode(PgValueRef { - type_info: element_type, - format: self.fmt, - value: element_buf, - row: None, - }) + T::decode(PgValueRef::get(&mut self.buf, self.fmt, element_type)) } PgValueFormat::Text => { diff --git a/sqlx-core/src/postgres/value.rs b/sqlx-core/src/postgres/value.rs index 1898cde3..00a34ac2 100644 --- a/sqlx-core/src/postgres/value.rs +++ b/sqlx-core/src/postgres/value.rs @@ -1,7 +1,7 @@ use std::borrow::Cow; use std::str::from_utf8; -use bytes::Bytes; +use bytes::{Buf, Bytes}; use crate::error::{BoxDynError, UnexpectedNullError}; use crate::postgres::{PgTypeInfo, Postgres}; @@ -32,6 +32,26 @@ pub struct PgValue { } impl<'r> PgValueRef<'r> { + pub(crate) fn get(buf: &mut &'r [u8], format: PgValueFormat, ty: PgTypeInfo) -> Self { + let mut element_len = buf.get_i32(); + + let element_val = if element_len == -1 { + element_len = 0; + None + } else { + Some(&buf[..(element_len as usize)]) + }; + + buf.advance(element_len as usize); + + PgValueRef { + value: element_val, + row: None, + type_info: ty, + format, + } + } + pub(crate) fn format(&self) -> PgValueFormat { self.format } @@ -62,7 +82,13 @@ impl Value for PgValue { } fn type_info(&self) -> Option> { - Some(Cow::Borrowed(&self.type_info)) + if self.format == PgValueFormat::Text { + // For TEXT encoding the type defined on the value is unreliable + // We don't even bother to return it so type checking is implicitly opted-out + None + } else { + Some(Cow::Borrowed(&self.type_info)) + } } fn is_null(&self) -> bool { @@ -90,7 +116,13 @@ impl<'r> ValueRef<'r> for PgValueRef<'r> { } fn type_info(&self) -> Option> { - Some(Cow::Borrowed(&self.type_info)) + if self.format == PgValueFormat::Text { + // For TEXT encoding the type defined on the value is unreliable + // We don't even bother to return it so type checking is implicitly opted-out + None + } else { + Some(Cow::Borrowed(&self.type_info)) + } } fn is_null(&self) -> bool { diff --git a/sqlx-macros/src/derives/attributes.rs b/sqlx-macros/src/derives/attributes.rs index dfa2d155..e349f2e9 100644 --- a/sqlx-macros/src/derives/attributes.rs +++ b/sqlx-macros/src/derives/attributes.rs @@ -146,15 +146,12 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result syn::Result<()> { +pub fn check_transparent_attributes( + input: &DeriveInput, + field: &Field, +) -> syn::Result { let attributes = parse_container_attributes(&input.attrs)?; - assert_attribute!( - attributes.transparent, - "expected #[sqlx(transparent)]", - input - ); - assert_attribute!( attributes.rename_all.is_none(), "unexpected #[sqlx(rename_all = ..)]", @@ -163,15 +160,15 @@ pub fn check_transparent_attributes(input: &DeriveInput, field: &Field) -> syn:: assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); - let attributes = parse_child_attributes(&field.attrs)?; + let ch_attributes = parse_child_attributes(&field.attrs)?; assert_attribute!( - attributes.rename.is_none(), + ch_attributes.rename.is_none(), "unexpected #[sqlx(rename = ..)]", field ); - Ok(()) + Ok(attributes) } pub fn check_enum_attributes<'a>(input: &'a DeriveInput) -> syn::Result { diff --git a/sqlx-macros/src/derives/decode.rs b/sqlx-macros/src/derives/decode.rs index 65663a11..a3cd102c 100644 --- a/sqlx-macros/src/derives/decode.rs +++ b/sqlx-macros/src/derives/decode.rs @@ -62,21 +62,21 @@ fn expand_derive_decode_transparent( // add db type for impl generics & where clause let mut generics = generics.clone(); generics.params.insert(0, parse_quote!(DB: sqlx::Database)); - generics.params.insert(0, parse_quote!('de)); + generics.params.insert(0, parse_quote!('r)); generics .make_where_clause() .predicates - .push(parse_quote!(#ty: sqlx::decode::Decode<'de, DB>)); + .push(parse_quote!(#ty: sqlx::decode::Decode<'r, DB>)); let (impl_generics, _, where_clause) = generics.split_for_impl(); let tts = quote!( - impl #impl_generics sqlx::decode::Decode<'de, DB> for #ident #ty_generics #where_clause { + impl #impl_generics sqlx::decode::Decode<'r, DB> for #ident #ty_generics #where_clause { fn accepts(ty: &DB::TypeInfo) -> bool { - <#ty as sqlx::decode::Decode<'de, DB>>::accepts(ty) + <#ty as sqlx::decode::Decode<'r, DB>>::accepts(ty) } - fn decode(value: >::ValueRef) -> std::result::Result { - <#ty as sqlx::decode::Decode<'de, DB>>::decode(value).map(Self) + fn decode(value: >::ValueRef) -> std::result::Result> { + <#ty as sqlx::decode::Decode<'r, DB>>::decode(value).map(Self) } } ); @@ -103,13 +103,13 @@ fn expand_derive_decode_weak_enum( .collect::>(); Ok(quote!( - impl<'de, DB: sqlx::Database> sqlx::decode::Decode<'de, DB> for #ident where #repr: sqlx::decode::Decode<'de, DB> { - fn accepts(ty: &MySqlTypeInfo) -> bool { - *ty == Self::type_info() + impl<'r, DB: sqlx::Database> sqlx::decode::Decode<'r, DB> for #ident where #repr: sqlx::decode::Decode<'r, DB> { + fn accepts(ty: &DB::TypeInfo) -> bool { + <#repr as sqlx::decode::Decode<'r, DB>>::accepts(ty) } - fn decode(value: >::ValueRef) -> std::result::Result { - let value = <#repr as sqlx::decode::Decode<'de, DB>>::decode(value)?; + fn decode(value: >::ValueRef) -> std::result::Result> { + let value = <#repr as sqlx::decode::Decode<'r, DB>>::decode(value)?; match value { #(#arms)* @@ -146,22 +146,65 @@ fn expand_derive_decode_strong_enum( } }); - Ok(quote!( - impl<'de, DB: sqlx::Database> sqlx::decode::Decode<'de, DB> for #ident where &'de str: sqlx::decode::Decode<'de, DB> { - fn accepts(ty: &MySqlTypeInfo) -> bool { - *ty == Self::type_info() - } + let values = quote! { + match value { + #(#value_arms)* - fn decode(value: >::ValueRef) -> std::result::Result { - let value = <&'de str as sqlx::decode::Decode<'de, DB>>::decode(value)?; - match value { - #(#value_arms)* + _ => Err(format!("invalid value {:?} for enum {}", value, #ident_s).into()) + } + }; - _ => Err(Box::new(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into()))) + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl<'r> sqlx::decode::Decode<'r, sqlx::mysql::MySql> for #ident { + fn accepts(ty: &sqlx::mysql::MySqlTypeInfo) -> bool { + ty == sqlx::mysql::MySqlTypeInfo::__enum() + } + + fn decode(value: sqlx::mysql::MySqlValueRef<'r>) -> std::result::Result> { + let value = <&'r str as sqlx::decode::Decode<'r, sqlx::mysql::MySql>>::decode(value)?; + + #values } } - } - )) + )); + } + + if cfg!(feature = "postgres") { + tts.extend(quote!( + impl<'r> sqlx::decode::Decode<'r, sqlx::postgres::Postgres> for #ident { + fn accepts(ty: &sqlx::postgres::PgTypeInfo) -> bool { + *ty == <#ident as sqlx::Type>::type_info() + } + + fn decode(value: sqlx::postgres::PgValueRef<'r>) -> std::result::Result> { + let value = <&'r str as sqlx::decode::Decode<'r, sqlx::postgres::Postgres>>::decode(value)?; + + #values + } + } + )); + } + + if cfg!(feature = "sqlite") { + tts.extend(quote!( + impl<'r> sqlx::decode::Decode<'r, sqlx::sqlite::Sqlite> for #ident { + fn accepts(ty: &sqlx::sqlite::SqliteTypeInfo) -> bool { + <&str as sqlx::decode::Decode<'r, DB>>::accepts(ty) + } + + fn decode(value: sqlx::sqlite::SqliteValueRef<'r>) -> std::result::Result> { + let value = <&'r str as sqlx::decode::Decode<'r, sqlx::sqlite::Sqlite>>::decode(value)?; + + #values + } + } + )); + } + + Ok(tts) } fn expand_derive_decode_struct( @@ -181,14 +224,14 @@ fn expand_derive_decode_struct( // add db type for impl generics & where clause let mut generics = generics.clone(); - generics.params.insert(0, parse_quote!('de)); + generics.params.insert(0, parse_quote!('r)); let predicates = &mut generics.make_where_clause().predicates; for field in fields { let ty = &field.ty; - predicates.push(parse_quote!(#ty: sqlx::decode::Decode<'de, sqlx::Postgres>)); + predicates.push(parse_quote!(#ty: sqlx::decode::Decode<'r, sqlx::Postgres>)); predicates.push(parse_quote!(#ty: sqlx::types::Type)); } @@ -199,20 +242,20 @@ fn expand_derive_decode_struct( let ty = &field.ty; parse_quote!( - let #id = decoder.decode::<#ty>()?; + let #id = decoder.try_decode::<#ty>()?; ) }); let names = fields.iter().map(|field| &field.ident); tts.extend(quote!( - impl #impl_generics sqlx::decode::Decode<'de, sqlx::Postgres> for #ident #ty_generics #where_clause { - fn accepts(ty: &MySqlTypeInfo) -> bool { - *ty == Self::type_info() + impl #impl_generics sqlx::decode::Decode<'r, sqlx::Postgres> for #ident #ty_generics #where_clause { + fn accepts(ty: &sqlx::postgres::PgTypeInfo) -> bool { + *ty == >::type_info() } - fn decode(value: >::RawValue) -> sqlx::Result { - let mut decoder = sqlx::postgres::types::raw::PgRecordDecoder::new(value)?; + fn decode(value: sqlx::postgres::PgValueRef<'r>) -> std::result::Result> { + let mut decoder = sqlx::postgres::types::PgRecordDecoder::new(value)?; #(#reads)* diff --git a/sqlx-macros/src/derives/encode.rs b/sqlx-macros/src/derives/encode.rs index af6a07d6..1560b07c 100644 --- a/sqlx-macros/src/derives/encode.rs +++ b/sqlx-macros/src/derives/encode.rs @@ -67,6 +67,7 @@ fn expand_derive_encode_transparent( generics .params .insert(0, LifetimeDef::new(lifetime.clone()).into()); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); generics .make_where_clause() @@ -76,18 +77,16 @@ fn expand_derive_encode_transparent( Ok(quote!( impl #impl_generics sqlx::encode::Encode<#lifetime, DB> for #ident #ty_generics #where_clause { - fn encode(self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { - sqlx::encode::Encode::encode(self.0, buf) + fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { + <#ty as sqlx::encode::Encode<#lifetime, DB>>::encode_by_ref(&self.0, buf) } - fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { - sqlx::encode::Encode::encode_by_ref(&self.0, buf) - } fn produces(&self) -> Option { - <#ty as sqlx::encode::Encode>::produces(&self.0) + <#ty as sqlx::encode::Encode<#lifetime, DB>>::produces(&self.0) } + fn size_hint(&self) -> usize { - sqlx::encode::Encode::size_hint(&self.0) + <#ty as sqlx::encode::Encode<#lifetime, DB>>::size_hint(&self.0) } } )) @@ -103,21 +102,17 @@ fn expand_derive_encode_weak_enum( let ident = &input.ident; Ok(quote!( - impl<'q, DB: sqlx::Database> sqlx::encode::Encode<'q, DB> for #ident where #repr: sqlx::encode::Encode<'q, DB> { - fn encode(self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { - sqlx::encode::Encode::encode((self as #repr), buf) - } - - fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { - sqlx::encode::Encode::encode_by_ref(&(*self as #repr), buf) - } + impl<'q, DB: sqlx::Database> sqlx::encode::Encode<'q, DB> for #ident where #repr: sqlx::encode::Encode<'q, DB> { + fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { + <#repr as sqlx::encode::Encode>::encode_by_ref(&(*self as #repr), buf) + } fn produces(&self) -> Option { - >::type_info().into() + <#repr as sqlx::encode::Encode>::produces(&(*self as #repr)) } fn size_hint(&self) -> usize { - sqlx::encode::Encode::size_hint(&(*self as #repr)) + <#repr as sqlx::encode::Encode>::size_hint(&(*self as #repr)) } } )) @@ -149,24 +144,21 @@ fn expand_derive_encode_strong_enum( } Ok(quote!( - impl<'q, DB: sqlx::Database> sqlx::encode::Encode<'q, DB> for #ident where str: sqlx::encode::Encode<'q, DB> { + impl<'q, DB: sqlx::Database> sqlx::encode::Encode<'q, DB> for #ident where &'q str: sqlx::encode::Encode<'q, DB> { fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { let val = match self { #(#value_arms)* }; - >::encode_by_ref(val, buf) - } - - fn produces(&self) -> Option { - >::type_info().into() + <&str as sqlx::encode::Encode<'q, DB>>::encode(val, buf) } fn size_hint(&self) -> usize { let val = match self { #(#value_arms)* }; - >::size_hint(val) + + <&str as sqlx::encode::Encode<'q, DB>>::size_hint(&val) } } )) @@ -190,14 +182,14 @@ fn expand_derive_encode_struct( // add db type for impl generics & where clause let mut generics = generics.clone(); + let predicates = &mut generics.make_where_clause().predicates; for field in fields { let ty = &field.ty; - predicates.insert(0, parse_quote!('q)); - predicates.push(parse_quote!(#ty: sqlx::encode::Encode<'q, sqlx::Postgres>)); - predicates.push(parse_quote!(#ty: sqlx::types::Type<'q, sqlx::Postgres>)); + predicates.push(parse_quote!(#ty: for<'q> sqlx::encode::Encode<'q, sqlx::Postgres>)); + predicates.push(parse_quote!(#ty: sqlx::types::Type)); } let (impl_generics, _, where_clause) = generics.split_for_impl(); @@ -206,7 +198,6 @@ fn expand_derive_encode_struct( let id = &field.ident; parse_quote!( - // sqlx::postgres::encode_struct_field(buf, &self. #id); encoder.encode(&self. #id); ) }); @@ -221,17 +212,15 @@ fn expand_derive_encode_struct( }); tts.extend(quote!( - impl #impl_generics sqlx::encode::Encode<'q, sqlx::Postgres> for #ident #ty_generics #where_clause { - fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { - let mut encoder = sqlx::postgres::types::raw::PgRecordEncoder::new(buf); + impl #impl_generics sqlx::encode::Encode<'_, sqlx::Postgres> for #ident #ty_generics #where_clause { + fn encode_by_ref(&self, buf: &mut sqlx::postgres::PgArgumentBuffer) -> sqlx::encode::IsNull { + let mut encoder = sqlx::postgres::types::PgRecordEncoder::new(buf); #(#writes)* - encoder.finish() - } + encoder.finish(); - fn produces(&self) -> Option { - >::type_info().into() + sqlx::encode::IsNull::No } fn size_hint(&self) -> usize { diff --git a/sqlx-macros/src/derives/type.rs b/sqlx-macros/src/derives/type.rs index cbbfe24a..fa8a9a30 100644 --- a/sqlx-macros/src/derives/type.rs +++ b/sqlx-macros/src/derives/type.rs @@ -49,32 +49,48 @@ fn expand_derive_has_sql_type_transparent( input: &DeriveInput, field: &Field, ) -> syn::Result { - check_transparent_attributes(input, field)?; + let attr = check_transparent_attributes(input, field)?; let ident = &input.ident; let ty = &field.ty; - // extract type generics let generics = &input.generics; let (_, ty_generics, _) = generics.split_for_impl(); - // add db type for clause - let mut generics = generics.clone(); - generics.params.insert(0, parse_quote!(DB: sqlx::Database)); - generics - .make_where_clause() - .predicates - .push(parse_quote!(#ty: sqlx::Type)); + if attr.transparent { + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::Type)); - let (impl_generics, _, where_clause) = generics.split_for_impl(); + let (impl_generics, _, where_clause) = generics.split_for_impl(); - Ok(quote!( - impl #impl_generics sqlx::Type< DB > for #ident #ty_generics #where_clause { - fn type_info() -> DB::TypeInfo { - <#ty as sqlx::Type>::type_info() + return Ok(quote!( + impl #impl_generics sqlx::Type< DB > for #ident #ty_generics #where_clause { + fn type_info() -> DB::TypeInfo { + <#ty as sqlx::Type>::type_info() + } } - } - )) + )); + } + + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "postgres") { + let ty_name = attr.rename.unwrap_or_else(|| ident.to_string()); + + tts.extend(quote!( + impl sqlx::Type< sqlx::postgres::Postgres > for #ident #ty_generics { + fn type_info() -> sqlx::postgres::PgTypeInfo { + sqlx::postgres::PgTypeInfo::with_name(#ty_name) + } + } + )); + } + + Ok(tts) } fn expand_derive_has_sql_type_weak_enum( @@ -84,8 +100,7 @@ fn expand_derive_has_sql_type_weak_enum( let attr = check_weak_enum_attributes(input, variants)?; let repr = attr.repr.unwrap(); let ident = &input.ident; - - Ok(quote!( + let ts = quote!( impl sqlx::Type for #ident where #repr: sqlx::Type, @@ -94,7 +109,9 @@ fn expand_derive_has_sql_type_weak_enum( <#repr as sqlx::Type>::type_info() } } - )) + ); + + Ok(ts) } fn expand_derive_has_sql_type_strong_enum( @@ -110,7 +127,7 @@ fn expand_derive_has_sql_type_strong_enum( tts.extend(quote!( impl sqlx::Type< sqlx::MySql > for #ident { fn type_info() -> sqlx::mysql::MySqlTypeInfo { - sqlx::mysql::MySqlTypeInfo::r#enum() + sqlx::mysql::MySqlTypeInfo::__enum() } } )); diff --git a/src/lib.rs b/src/lib.rs index d88bf113..8c84caef 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,16 +11,14 @@ pub use sqlx_core::query_as::{query_as, query_as_with}; pub use sqlx_core::query_scalar::{query_scalar, query_scalar_with}; pub use sqlx_core::row::{ColumnIndex, Row}; pub use sqlx_core::transaction::{Transaction, TransactionManager}; +pub use sqlx_core::types::Type; pub use sqlx_core::value::{Value, ValueRef}; #[doc(hidden)] pub use sqlx_core::describe; #[doc(inline)] -pub use sqlx_core::types::{self, Type}; - -#[doc(inline)] -pub use sqlx_core::error::{self, BoxDynError, Error, Result}; +pub use sqlx_core::error::{self, Error, Result}; #[cfg(feature = "mysql")] #[cfg_attr(docsrs, doc(cfg(feature = "mysql")))] @@ -42,6 +40,7 @@ pub use sqlx_core::sqlite::{self, Sqlite, SqliteConnection, SqlitePool}; #[doc(hidden)] pub extern crate sqlx_macros; +// derives #[cfg(feature = "macros")] pub use sqlx_macros::{FromRow, Type}; @@ -57,6 +56,13 @@ pub mod ty_match; #[doc(hidden)] pub mod result_ext; +pub mod types { + pub use sqlx_core::types::*; + + #[cfg(feature = "macros")] + pub use sqlx_macros::Type; +} + /// Types and traits for encoding values for the database. pub mod encode { pub use sqlx_core::encode::{Encode, IsNull}; diff --git a/tests/postgres/derives.rs b/tests/postgres/derives.rs index 32550e86..14dacdd9 100644 --- a/tests/postgres/derives.rs +++ b/tests/postgres/derives.rs @@ -1,6 +1,9 @@ -use sqlx::{postgres::PgQueryAs, Connection, Cursor, Executor, FromRow, Postgres}; +use futures::TryStreamExt; +use sqlx::{Connection, Executor, FromRow, Postgres}; +use sqlx_core::postgres::types::PgRange; use sqlx_test::{new, test_type}; use std::fmt::Debug; +use std::ops::Bound; // Transparent types are rust-side wrappers over DB types #[derive(PartialEq, Debug, sqlx::Type)] @@ -37,6 +40,7 @@ enum ColorLower { Green, Blue, } + #[derive(PartialEq, Debug, sqlx::Type)] #[sqlx(rename = "color_snake")] #[sqlx(rename_all = "snake_case")] @@ -44,6 +48,7 @@ enum ColorSnake { RedGreen, BlueBlack, } + #[derive(PartialEq, Debug, sqlx::Type)] #[sqlx(rename = "color_upper")] #[sqlx(rename_all = "uppercase")] @@ -73,36 +78,43 @@ struct InventoryItem { price: Option, } -test_type!(transparent( - Postgres, - Transparent, +// Custom range type +#[derive(sqlx::Type, Debug, PartialEq)] +#[sqlx(rename = "float_range")] +struct FloatRange(PgRange); + +// Custom domain type +#[derive(sqlx::Type, Debug)] +#[sqlx(rename = "int4rangeL0pC")] +struct RangeInclusive(PgRange); + +test_type!(transparent(Postgres, "0" == Transparent(0), "23523" == Transparent(23523) )); -test_type!(weak_enum( - Postgres, - Weak, +test_type!(weak_enum(Postgres, "0::int4" == Weak::One, "2::int4" == Weak::Two, "4::int4" == Weak::Three )); -test_type!(strong_enum( - Postgres, - Strong, +test_type!(strong_enum(Postgres, "'one'::text" == Strong::One, "'two'::text" == Strong::Two, "'four'::text" == Strong::Three )); +test_type!(floatrange(Postgres, + "'[1.234, 5.678]'::float_range" == FloatRange(PgRange::from((Bound::Included(1.234), Bound::Included(5.678)))), +)); + #[sqlx_macros::test] async fn test_enum_type() -> anyhow::Result<()> { let mut conn = new::().await?; conn.execute( r#" - DROP TABLE IF EXISTS people; DROP TYPE IF EXISTS mood CASCADE; @@ -154,7 +166,7 @@ RETURNING id let rec: PeopleRow = sqlx::query_as( " SELECT id, mood FROM people WHERE id = $1 - ", + ", ) .bind(people_id) .fetch_one(&mut conn) @@ -169,20 +181,23 @@ SELECT id, mood FROM people WHERE id = $1 let stmt = format!("SELECT id, mood FROM people WHERE id = {}", people_id); dbg!(&stmt); + let mut cursor = conn.fetch(&*stmt); - let row = cursor.next().await?.unwrap(); + let row = cursor.try_next().await?.unwrap(); let rec = PeopleRow::from_row(&row)?; assert_eq!(rec.id, people_id); assert_eq!(rec.mood, Mood::Sad); + drop(cursor); + // Normal type equivalency test let rec: (bool, Mood) = sqlx::query_as( " -SELECT $1 = 'happy'::mood, $1 - ", + SELECT $1 = 'happy'::mood, $1 + ", ) .bind(&Mood::Happy) .fetch_one(&mut conn) @@ -193,8 +208,8 @@ SELECT $1 = 'happy'::mood, $1 let rec: (bool, ColorLower) = sqlx::query_as( " -SELECT $1 = 'green'::color_lower, $1 - ", + SELECT $1 = 'green'::color_lower, $1 + ", ) .bind(&ColorLower::Green) .fetch_one(&mut conn) @@ -205,8 +220,8 @@ SELECT $1 = 'green'::color_lower, $1 let rec: (bool, ColorSnake) = sqlx::query_as( " -SELECT $1 = 'red_green'::color_snake, $1 - ", + SELECT $1 = 'red_green'::color_snake, $1 + ", ) .bind(&ColorSnake::RedGreen) .fetch_one(&mut conn) @@ -217,8 +232,8 @@ SELECT $1 = 'red_green'::color_snake, $1 let rec: (bool, ColorUpper) = sqlx::query_as( " -SELECT $1 = 'RED'::color_upper, $1 - ", + SELECT $1 = 'RED'::color_upper, $1 + ", ) .bind(&ColorUpper::Red) .fetch_one(&mut conn) @@ -234,23 +249,6 @@ SELECT $1 = 'RED'::color_upper, $1 async fn test_record_type() -> anyhow::Result<()> { let mut conn = new::().await?; - conn.execute( - r#" -DO $$ BEGIN - -CREATE TYPE inventory_item AS ( - name text, - supplier_id int, - price bigint -); - -EXCEPTION - WHEN duplicate_object THEN null; -END $$; - "#, - ) - .await?; - let value = InventoryItem { name: "fuzzy dice".to_owned(), supplier_id: Some(42), @@ -259,7 +257,7 @@ END $$; let rec: (bool, InventoryItem) = sqlx::query_as( " - SELECT $1 = ROW('fuzzy dice', 42, 199)::inventory_item, $1 +SELECT $1 = ROW('fuzzy dice', 42, 199)::inventory_item, $1 ", ) .bind(&value) @@ -275,9 +273,6 @@ END $$; #[cfg(feature = "macros")] #[sqlx_macros::test] async fn test_from_row() -> anyhow::Result<()> { - // Needed for PgQueryAs - use sqlx::prelude::*; - let mut conn = new::().await?; #[derive(sqlx::FromRow)] @@ -310,7 +305,8 @@ async fn test_from_row() -> anyhow::Result<()> { .bind(1_i32) .fetch(&mut conn); - let account = RefAccount::from_row(&cursor.next().await?.unwrap())?; + let row = cursor.try_next().await?.unwrap(); + let account = RefAccount::from_row(&row)?; assert_eq!(account.id, 1); assert_eq!(account.name, "Herp Derpinson"); @@ -321,8 +317,6 @@ async fn test_from_row() -> anyhow::Result<()> { #[cfg(feature = "macros")] #[sqlx_macros::test] async fn test_from_row_with_keyword() -> anyhow::Result<()> { - use sqlx::prelude::*; - #[derive(Debug, sqlx::FromRow)] struct AccountKeyword { r#type: i32, @@ -353,8 +347,6 @@ async fn test_from_row_with_keyword() -> anyhow::Result<()> { #[cfg(feature = "macros")] #[sqlx_macros::test] async fn test_from_row_with_rename() -> anyhow::Result<()> { - use sqlx::prelude::*; - #[derive(Debug, sqlx::FromRow)] struct AccountKeyword { #[sqlx(rename = "type")] diff --git a/tests/postgres/setup.sql b/tests/postgres/setup.sql index 651708a6..10374487 100644 --- a/tests/postgres/setup.sql +++ b/tests/postgres/setup.sql @@ -17,3 +17,9 @@ CREATE TABLE tweet text TEXT NOT NULL, owner_id BIGINT ); + +CREATE TYPE float_range AS RANGE +( + subtype = float8, + subtype_diff = float8mi +); diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index 1931a41e..a6518c2e 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -1,5 +1,8 @@ extern crate time_ as time; +use std::ops::Bound; + +use sqlx::postgres::types::PgRange; use sqlx::postgres::Postgres; use sqlx_test::{test_decode_type, test_prepared_type, test_type}; @@ -334,35 +337,26 @@ test_type!(decimal(Postgres, "12345.6789::numeric" == "12345.6789".parse::().unwrap(), )); -mod ranges { - use super::*; - use core::ops::Bound; - use sqlx::postgres::types::{Int4Range, PgRange}; +const EXC2: Bound = Bound::Excluded(2); +const EXC3: Bound = Bound::Excluded(3); +const INC1: Bound = Bound::Included(1); +const INC2: Bound = Bound::Included(2); +const UNB: Bound = Bound::Unbounded; - const EXC2: Bound = Bound::Excluded(2); - const EXC3: Bound = Bound::Excluded(3); - const INC1: Bound = Bound::Included(1); - const INC2: Bound = Bound::Included(2); - const UNB: Bound = Bound::Unbounded; - - // int4range display is hard-coded into [l, u) - test_type!(int4range>(Postgres, - - "'(,)'::int4range" == Int4Range(PgRange::new([UNB, UNB])), - "'(,]'::int4range" == Int4Range(PgRange::new([UNB, UNB])), - "'(,2)'::int4range" == Int4Range(PgRange::new([UNB, EXC2])), - "'(,2]'::int4range" == Int4Range(PgRange::new([UNB, EXC3])), - "'(1,)'::int4range" == Int4Range(PgRange::new([INC2, UNB])), - "'(1,]'::int4range" == Int4Range(PgRange::new([INC2, UNB])), - "'(1,2]'::int4range" == Int4Range(PgRange::new([INC2, EXC3])), - - "'[,)'::int4range" == Int4Range(PgRange::new([UNB, UNB])), - "'[,]'::int4range" == Int4Range(PgRange::new([UNB, UNB])), - "'[,2)'::int4range" == Int4Range(PgRange::new([UNB, EXC2])), - "'[,2]'::int4range" == Int4Range(PgRange::new([UNB, EXC3])), - "'[1,)'::int4range" == Int4Range(PgRange::new([INC1, UNB])), - "'[1,]'::int4range" == Int4Range(PgRange::new([INC1, UNB])), - "'[1,2)'::int4range" == Int4Range(PgRange::new([INC1, EXC2])), - "'[1,2]'::int4range" == Int4Range(PgRange::new([INC1, EXC3])), - )); -} +test_type!(int4range>(Postgres, + "'(,)'::int4range" == PgRange::from((UNB, UNB)), + "'(,]'::int4range" == PgRange::from((UNB, UNB)), + "'(,2)'::int4range" == PgRange::from((UNB, EXC2)), + "'(,2]'::int4range" == PgRange::from((UNB, EXC3)), + "'(1,)'::int4range" == PgRange::from((INC2, UNB)), + "'(1,]'::int4range" == PgRange::from((INC2, UNB)), + "'(1,2]'::int4range" == PgRange::from((INC2, EXC3)), + "'[,)'::int4range" == PgRange::from((UNB, UNB)), + "'[,]'::int4range" == PgRange::from((UNB, UNB)), + "'[,2)'::int4range" == PgRange::from((UNB, EXC2)), + "'[,2]'::int4range" == PgRange::from((UNB, EXC3)), + "'[1,)'::int4range" == PgRange::from((INC1, UNB)), + "'[1,]'::int4range" == PgRange::from((INC1, UNB)), + "'[1,2)'::int4range" == PgRange::from((INC1, EXC2)), + "'[1,2]'::int4range" == PgRange::from((INC1, EXC3)), +));