From 9fc011d827861894c9e395b4ff0a140b7cb92f57 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 30 Apr 2021 18:52:14 -0700 Subject: [PATCH] feat: add generic `Array` adapter Signed-off-by: Austin Bonander --- sqlx-core/Cargo.toml | 2 + sqlx-core/src/postgres/arguments.rs | 38 ++++ sqlx-core/src/postgres/types/array.rs | 285 ++++++++++++++------------ sqlx-core/src/query.rs | 13 ++ sqlx-core/src/query_as.rs | 13 ++ sqlx-core/src/query_scalar.rs | 13 ++ sqlx-core/src/types/array.rs | 110 ++++++++++ sqlx-core/src/types/mod.rs | 7 + tests/postgres/types.rs | 17 ++ 9 files changed, 364 insertions(+), 134 deletions(-) create mode 100644 sqlx-core/src/types/array.rs diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 8f867869..6e6c8e7b 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -32,6 +32,7 @@ postgres = [ "futures-util/sink", "json", "dirs", + "array" ] mysql = [ "sha-1", @@ -58,6 +59,7 @@ all-types = [ "uuid", "bit-vec", ] +array = [] bigdecimal = ["bigdecimal_", "num-bigint"] decimal = ["rust_decimal", "num-bigint"] json = ["serde", "serde_json"] diff --git a/sqlx-core/src/postgres/arguments.rs b/sqlx-core/src/postgres/arguments.rs index 9bd60dbb..ede9fcfd 100644 --- a/sqlx-core/src/postgres/arguments.rs +++ b/sqlx-core/src/postgres/arguments.rs @@ -4,8 +4,10 @@ use crate::arguments::Arguments; use crate::encode::{Encode, IsNull}; use crate::error::Error; use crate::ext::ustr::UStr; +use crate::postgres::type_info::PgType; use crate::postgres::{PgConnection, PgTypeInfo, Postgres}; use crate::types::Type; +use std::convert::TryInto; // TODO: buf.patch(|| ...) is a poor name, can we think of a better name? Maybe `buf.lazy(||)` ? 
// TODO: Extend the patch system to support dynamic lengths @@ -141,6 +143,42 @@ impl PgArgumentBuffer { self[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes()); } + pub(crate) fn encode_iter<'q, T, I>(&mut self, iter: I) + where + T: Encode<'q, Postgres> + Type, + I: IntoIterator, + { + self.extend(&1_i32.to_be_bytes()); // number of dimensions + self.extend(&0_i32.to_be_bytes()); // flags + + // element type + match T::type_info().0 { + PgType::DeclareWithName(name) => self.patch_type_by_name(&name), + + ty => { + self.extend(&ty.oid().to_be_bytes()); + } + } + + let len_at = self.len(); + + self.extend(&[0u8; 4]); // len (initially zero but we'll fix this up) + self.extend(&1_i32.to_be_bytes()); // lower bound + + // count while encoding items at the same time + let len: i32 = iter + .into_iter() + .map(|item| item.encode(self)) + .count() + .try_into() + // in practice, Postgres will reject arrays significantly smaller than this: + // https://github.com/postgres/postgres/blob/e6f9539dc32473793c03cbe95bc099ee0a199c73/src/backend/utils/adt/arrayutils.c#L66 + .expect("array length exceeds maximum the Postgres protocol can handle"); + + // fixup the actual length + self[len_at..len_at + 4].copy_from_slice(&len.to_be_bytes()); + } + // Adds a callback to be invoked later when we know the parameter type #[allow(dead_code)] pub(crate) fn patch(&mut self, callback: F) diff --git a/sqlx-core/src/postgres/types/array.rs b/sqlx-core/src/postgres/types/array.rs index cf2baea4..c58b3c7f 100644 --- a/sqlx-core/src/postgres/types/array.rs +++ b/sqlx-core/src/postgres/types/array.rs @@ -1,11 +1,13 @@ use bytes::Buf; +use crate::database::{HasArguments, HasValueRef}; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::postgres::type_info::PgType; use crate::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; use crate::types::Type; +use std::iter::FromIterator; impl Type for [Option] where 
@@ -33,14 +35,28 @@ where } } +impl Type for crate::types::Array +where + I: IntoIterator, + [I::Item]: Type, +{ + fn type_info() -> PgTypeInfo { + <[I::Item] as Type>::type_info() + } + + fn compatible(ty: &PgTypeInfo) -> bool { + <[I::Item] as Type>::compatible(ty) + } +} + impl<'q, T> Encode<'q, Postgres> for Vec where - for<'a> &'a [T]: Encode<'q, Postgres>, - T: Encode<'q, Postgres>, + T: Encode<'q, Postgres> + Type, { #[inline] fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { - self.as_slice().encode_by_ref(buf) + buf.encode_iter(self.as_slice()); + IsNull::No } } @@ -49,25 +65,18 @@ where T: Encode<'q, Postgres> + Type, { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { - buf.extend(&1_i32.to_be_bytes()); // number of dimensions - buf.extend(&0_i32.to_be_bytes()); // flags - - // element type - match T::type_info().0 { - PgType::DeclareWithName(name) => buf.patch_type_by_name(&name), - - ty => { - buf.extend(&ty.oid().to_be_bytes()); - } - } - - buf.extend(&(self.len() as i32).to_be_bytes()); // len - buf.extend(&1_i32.to_be_bytes()); // lower bound - - for element in self.iter() { - buf.encode(element); - } + buf.encode_iter(*self); + IsNull::No + } +} +impl<'q, T, I> Encode<'q, Postgres> for crate::types::Array +where + for<'a> &'a I: IntoIterator, + T: Encode<'q, Postgres> + Type, +{ + fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> IsNull { + buf.encode_iter(&self.0); IsNull::No } } @@ -77,141 +86,149 @@ where T: for<'a> Decode<'a, Postgres> + Type, { fn decode(value: PgValueRef<'r>) -> Result { - let element_type_info; - let format = value.format(); + // `impl FromIterator for Vec` is specialized for `vec::IntoIter`: + // https://github.com/rust-lang/rust/blob/8a9fa3682dcf0de095ec308a29a7b19b0e011ef0/library/alloc/src/vec/spec_from_iter.rs + decode_array(value) + } +} - match format { - PgValueFormat::Binary => { - // 
https://github.com/postgres/postgres/blob/a995b371ae29de2d38c4b7881cf414b1560e9746/src/backend/utils/adt/arrayfuncs.c#L1548 +impl<'r, I> Decode<'r, Postgres> for crate::types::Array +where + I: IntoIterator + FromIterator<::Item>, + I::Item: for<'a> Decode<'a, Postgres> + Type, +{ + fn decode(value: PgValueRef<'r>) -> Result { + decode_array(value).map(Self) + } +} - let mut buf = value.as_bytes()?; +fn decode_array(value: PgValueRef<'_>) -> Result +where + I: FromIterator, + T: for<'a> Decode<'a, Postgres> + Type, +{ + let element_type_info; + let format = value.format(); - // number of dimensions in the array - let ndim = buf.get_i32(); + match format { + PgValueFormat::Binary => { + // https://github.com/postgres/postgres/blob/a995b371ae29de2d38c4b7881cf414b1560e9746/src/backend/utils/adt/arrayfuncs.c#L1548 - if ndim == 0 { - // zero dimensions is an empty array - return Ok(Vec::new()); - } + let mut buf = value.as_bytes()?; - if ndim != 1 { - return Err(format!("encountered an array of {} dimensions; only one-dimensional arrays are supported", ndim).into()); - } + // number of dimensions in the array + let ndim = buf.get_i32(); - // appears to have been used in the past to communicate potential NULLS - // but reading source code back through our supported postgres versions (9.5+) - // this is never used for anything - let _flags = buf.get_i32(); - - // the OID of the element - let element_type_oid = buf.get_u32(); - element_type_info = PgTypeInfo::try_from_oid(element_type_oid) - .unwrap_or_else(|| PgTypeInfo(PgType::DeclareWithOid(element_type_oid))); - - // length of the array axis - let len = buf.get_i32(); - - // the lower bound, we only support arrays starting from "1" - let lower = buf.get_i32(); - - if lower != 1 { - return Err(format!("encountered an array with a lower bound of {} in the first dimension; only arrays starting at one are supported", lower).into()); - } - - let mut elements = Vec::with_capacity(len as usize); - - for _ in 0..len { - 
elements.push(T::decode(PgValueRef::get( - &mut buf, - format, - element_type_info.clone(), - ))?) - } - - Ok(elements) + if ndim == 0 { + // zero dimensions is an empty array + return Ok(I::from_iter(std::iter::empty())); } - PgValueFormat::Text => { - // no type is provided from the database for the element - element_type_info = T::type_info(); + if ndim != 1 { + return Err(format!("encountered an array of {} dimensions; only one-dimensional arrays are supported", ndim).into()); + } - let s = value.as_str()?; + // appears to have been used in the past to communicate potential NULLS + // but reading source code back through our supported postgres versions (9.5+) + // this is never used for anything + let _flags = buf.get_i32(); - // https://github.com/postgres/postgres/blob/a995b371ae29de2d38c4b7881cf414b1560e9746/src/backend/utils/adt/arrayfuncs.c#L718 + // the OID of the element + let element_type_oid = buf.get_u32(); + element_type_info = PgTypeInfo::try_from_oid(element_type_oid) + .unwrap_or_else(|| PgTypeInfo(PgType::DeclareWithOid(element_type_oid))); - // trim the wrapping braces - let s = &s[1..(s.len() - 1)]; + // length of the array axis + let len = buf.get_i32(); - if s.is_empty() { - // short-circuit empty arrays up here - return Ok(Vec::new()); + // the lower bound, we only support arrays starting from "1" + let lower = buf.get_i32(); + + if lower != 1 { + return Err(format!("encountered an array with a lower bound of {} in the first dimension; only arrays starting at one are supported", lower).into()); + } + + (0..len) + .map(|_| T::decode(PgValueRef::get(&mut buf, format, element_type_info.clone()))) + .collect() + } + + PgValueFormat::Text => { + // no type is provided from the database for the element + element_type_info = T::type_info(); + + let s = value.as_str()?; + + // https://github.com/postgres/postgres/blob/a995b371ae29de2d38c4b7881cf414b1560e9746/src/backend/utils/adt/arrayfuncs.c#L718 + + // trim the wrapping braces + let s = 
&s[1..(s.len() - 1)]; + + if s.is_empty() { + // short-circuit empty arrays up here + return Ok(I::from_iter(std::iter::empty())); + } + + // NOTE: Nearly *all* types use ',' as the sequence delimiter. Yes, there is one + // that does not. The BOX (not PostGIS) type uses ';' as a delimiter. + + // TODO: When we add support for BOX we need to figure out some way to make the + // delimiter selection + + let delimiter = ','; + let mut in_quotes = false; + let mut in_escape = false; + let mut value = String::with_capacity(10); + let mut chars = s.chars(); + + std::iter::from_fn(|| { + if chars.as_str().is_empty() { + return None; } - // NOTE: Nearly *all* types use ',' as the sequence delimiter. Yes, there is one - // that does not. The BOX (not PostGIS) type uses ';' as a delimiter. + for ch in &mut chars { + match ch { + _ if in_escape => { + value.push(ch); + in_escape = false; + } - // TODO: When we add support for BOX we need to figure out some way to make the - // delimiter selection + '"' => { + in_quotes = !in_quotes; + } - let delimiter = ','; - let mut done = false; - let mut in_quotes = false; - let mut in_escape = false; - let mut value = String::with_capacity(10); - let mut chars = s.chars(); - let mut elements = Vec::with_capacity(4); + '\\' => { + in_escape = true; + } - while !done { - loop { - match chars.next() { - Some(ch) => match ch { - _ if in_escape => { - value.push(ch); - in_escape = false; - } + _ if ch == delimiter && !in_quotes => { + break; + } - '"' => { - in_quotes = !in_quotes; - } - - '\\' => { - in_escape = true; - } - - _ if ch == delimiter && !in_quotes => { - break; - } - - _ => { - value.push(ch); - } - }, - - None => { - done = true; - break; - } + _ => { + value.push(ch); } } - - let value_opt = if value == "NULL" { - None - } else { - Some(value.as_bytes()) - }; - - elements.push(T::decode(PgValueRef { - value: value_opt, - row: None, - type_info: element_type_info.clone(), - format, - })?); - - value.clear(); } - Ok(elements) - 
} + let value_opt = if value == "NULL" { + None + } else { + Some(value.as_bytes()) + }; + + let ret = T::decode(PgValueRef { + value: value_opt, + row: None, + type_info: element_type_info.clone(), + format, + }); + + value.clear(); + + Some(ret) + }) + .collect() } } } diff --git a/sqlx-core/src/query.rs b/sqlx-core/src/query.rs index b3e30dc5..361f9057 100644 --- a/sqlx-core/src/query.rs +++ b/sqlx-core/src/query.rs @@ -83,6 +83,19 @@ impl<'q, DB: Database> Query<'q, DB, >::Arguments> { self } + + /// Bind any iterable as an array. + /// Only supported on databases with first-class arrays, like Postgres. + /// + /// See also: [Array][crate::types::Array] + #[cfg(feature = "array")] + #[cfg_attr(docsrs, doc(cfg(feature = "array")))] + pub fn bind_array(mut self, iter: I) -> Self + where + crate::types::Array: Encode<'q, DB> + Type + Send + 'q, + { + self.bind(crate::types::Array(iter)) + } } impl<'q, DB, A> Query<'q, DB, A> diff --git a/sqlx-core/src/query_as.rs b/sqlx-core/src/query_as.rs index 62406c21..a5631003 100644 --- a/sqlx-core/src/query_as.rs +++ b/sqlx-core/src/query_as.rs @@ -55,6 +55,19 @@ impl<'q, DB: Database, O> QueryAs<'q, DB, O, >::Arguments self.inner = self.inner.bind(value); self } + + /// Bind any iterable as an array. + /// Only supported on databases with first-class arrays, like Postgres. 
+ /// + /// See also: [Array][crate::types::Array] + #[cfg(feature = "array")] + #[cfg_attr(docsrs, doc(cfg(feature = "array")))] + pub fn bind_array(mut self, iter: I) -> Self + where + crate::types::Array: Encode<'q, DB> + Type + Send + 'q, + { + self.bind(crate::types::Array(iter)) + } } // FIXME: This is very close, nearly 1:1 with `Map` diff --git a/sqlx-core/src/query_scalar.rs b/sqlx-core/src/query_scalar.rs index 7e958a7b..760d3a80 100644 --- a/sqlx-core/src/query_scalar.rs +++ b/sqlx-core/src/query_scalar.rs @@ -52,6 +52,19 @@ impl<'q, DB: Database, O> QueryScalar<'q, DB, O, >::Argum self.inner = self.inner.bind(value); self } + + /// Bind any iterable as an array. + /// Only supported on databases with first-class arrays, like Postgres. + /// + /// See also: [Array][crate::types::Array] + #[cfg(feature = "array")] + #[cfg_attr(docsrs, doc(cfg(feature = "array")))] + pub fn bind_array(mut self, iter: I) -> Self + where + crate::types::Array: Encode<'q, DB> + Type + Send + 'q, + { + self.bind(crate::types::Array(iter)) + } } // FIXME: This is very close, nearly 1:1 with `Map` diff --git a/sqlx-core/src/types/array.rs b/sqlx-core/src/types/array.rs new file mode 100644 index 00000000..fedc3871 --- /dev/null +++ b/sqlx-core/src/types/array.rs @@ -0,0 +1,110 @@ +use std::ops::{Deref, DerefMut}; + +/// A generic adapter for encoding and decoding any type that implements +/// [`IntoIterator`][std::iter::IntoIterator]/[`FromIterator`][std::iter::FromIterator] +/// to or from an array in SQL, respectively. +/// +/// Only supported on databases that have native support for arrays, such as PostgreSQL. 
+/// +/// ## Examples +/// +/// #### (Postgres) Bulk Insert with Array of Structs -> Struct of Arrays +/// +/// You can implement bulk insert of structs by turning an array of structs into +/// an array for each field in the struct and then using Postgres' `UNNEST()`. +/// +/// ```rust,ignore +/// use sqlx::types::Array; +/// +/// struct Foo { +/// bar: String, +/// baz: i32, +/// quux: bool +/// } +/// +/// let foos = vec![ +/// Foo { +/// bar: "bar".to_string(), +/// baz: 0, +/// quux: true +/// } +/// ]; +/// +/// sqlx::query!( +/// " +/// INSERT INTO foo(bar, baz, quux) +/// SELECT * FROM UNNEST($1, $2, $3) +/// ", +/// // type overrides are necessary for the macros to accept this instead of `[String]`, etc. +/// Array(foos.iter().map(|foo| &foo.bar)) as _, +/// Array(foos.iter().map(|foo| foo.baz)) as _, +/// Array(foos.iter().map(|foo| foo.quux)) as _ +/// ) +/// .execute(&pool) +/// .await?; +/// ``` +/// +/// #### (Postgres) Deserialize a Different Type than `Vec` +/// +/// ```sql,ignore +/// CREATE TABLE media( +/// id BIGSERIAL PRIMARY KEY, +/// filename TEXT NOT NULL, +/// tags TEXT[] NOT NULL +/// ) +/// ``` +/// +/// ```rust,ignore +/// use sqlx::types::Array; +/// +/// use std::collections::HashSet; +/// +/// struct Media { +/// id: i64, +/// filename: String, +/// tags: Array<HashSet<String>>, +/// } +/// +/// let media: Vec<Media> = sqlx::query_as!(Media, +/// r#" +/// SELECT id, filename, tags AS "tags: Array<HashSet<String>>" FROM media +/// "# +/// ) +/// .fetch_all(&pool) +/// .await?; +/// ``` +#[derive(Debug)] +pub struct Array(pub I); + +impl Array { + pub fn into_inner(self) -> I { + self.0 + } +} + +impl Deref for Array { + type Target = I; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for Array { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl From for Array { + fn from(iterable: I) -> Self { + Self(iterable) + } +} + +// orphan trait impl error +// impl From> for I { +// fn from(array: Array) -> Self { +// array.0 +// } +// } diff --git 
a/sqlx-core/src/types/mod.rs b/sqlx-core/src/types/mod.rs index 09c4de85..e47c55c6 100644 --- a/sqlx-core/src/types/mod.rs +++ b/sqlx-core/src/types/mod.rs @@ -20,6 +20,10 @@ use crate::database::Database; +#[cfg(feature = "array")] +#[cfg_attr(docsrs, doc(cfg(feature = "array")))] +mod array; + #[cfg(feature = "bstr")] #[cfg_attr(docsrs, doc(cfg(feature = "bstr")))] pub mod bstr; @@ -75,6 +79,9 @@ pub mod ipnetwork { pub use ipnetwork::{IpNetwork, Ipv4Network, Ipv6Network}; } +#[cfg(feature = "array")] +pub use array::Array; + #[cfg(feature = "json")] pub use json::Json; diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index a0aa64eb..fe476f53 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -476,3 +476,20 @@ test_prepared_type!(money(Postgres, "123.45::money" == PgMoney(12345))) test_prepared_type!(money_vec>(Postgres, "array[123.45,420.00,666.66]::money[]" == vec![PgMoney(12345), PgMoney(42000), PgMoney(66666)], )); + +mod array { + use sqlx::types::Array; + + use std::collections::HashSet; + + macro_rules! set [ + ($($item:expr),*) => {{ + let mut set = HashSet::new(); + $(set.insert($item);)* + set + }} + ]; + + test_type!(array_to_hashset(Postgres, + "array['foo', 'bar', 'baz']" == Array(set!["foo", "bar", "baz"]))); +}