feat: add generic Array adapter

Signed-off-by: Austin Bonander <austin@launchbadge.com>
This commit is contained in:
Austin Bonander 2021-04-30 18:52:14 -07:00
parent 405474b575
commit 9fc011d827
No known key found for this signature in database
GPG Key ID: 4E7DA63E66AFC37E
9 changed files with 364 additions and 134 deletions

View File

@ -32,6 +32,7 @@ postgres = [
"futures-util/sink", "futures-util/sink",
"json", "json",
"dirs", "dirs",
"array"
] ]
mysql = [ mysql = [
"sha-1", "sha-1",
@ -58,6 +59,7 @@ all-types = [
"uuid", "uuid",
"bit-vec", "bit-vec",
] ]
array = []
bigdecimal = ["bigdecimal_", "num-bigint"] bigdecimal = ["bigdecimal_", "num-bigint"]
decimal = ["rust_decimal", "num-bigint"] decimal = ["rust_decimal", "num-bigint"]
json = ["serde", "serde_json"] json = ["serde", "serde_json"]

View File

@ -4,8 +4,10 @@ use crate::arguments::Arguments;
use crate::encode::{Encode, IsNull}; use crate::encode::{Encode, IsNull};
use crate::error::Error; use crate::error::Error;
use crate::ext::ustr::UStr; use crate::ext::ustr::UStr;
use crate::postgres::type_info::PgType;
use crate::postgres::{PgConnection, PgTypeInfo, Postgres}; use crate::postgres::{PgConnection, PgTypeInfo, Postgres};
use crate::types::Type; use crate::types::Type;
use std::convert::TryInto;
// TODO: buf.patch(|| ...) is a poor name, can we think of a better name? Maybe `buf.lazy(||)` ? // TODO: buf.patch(|| ...) is a poor name, can we think of a better name? Maybe `buf.lazy(||)` ?
// TODO: Extend the patch system to support dynamic lengths // TODO: Extend the patch system to support dynamic lengths
@ -141,6 +143,42 @@ impl PgArgumentBuffer {
self[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes()); self[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes());
} }
pub(crate) fn encode_iter<'q, T, I>(&mut self, iter: I)
where
T: Encode<'q, Postgres> + Type<Postgres>,
I: IntoIterator<Item = T>,
{
self.extend(&1_i32.to_be_bytes()); // number of dimensions
self.extend(&0_i32.to_be_bytes()); // flags
// element type
match T::type_info().0 {
PgType::DeclareWithName(name) => self.patch_type_by_name(&name),
ty => {
self.extend(&ty.oid().to_be_bytes());
}
}
let len_at = self.len();
self.extend(&[0u8; 4]); // len (initially zero but we'll fix this up)
self.extend(&1_i32.to_be_bytes()); // lower bound
// count while encoding items at the same time
let len: i32 = iter
.into_iter()
.map(|item| item.encode(self))
.count()
.try_into()
// in practice, Postgres will reject arrays significantly smaller than this:
// https://github.com/postgres/postgres/blob/e6f9539dc32473793c03cbe95bc099ee0a199c73/src/backend/utils/adt/arrayutils.c#L66
.expect("array length exceeds maximum the Postgres protocol can handle");
// fixup the actual length
self[len_at..len_at + 4].copy_from_slice(&len.to_be_bytes());
}
// Adds a callback to be invoked later when we know the parameter type // Adds a callback to be invoked later when we know the parameter type
#[allow(dead_code)] #[allow(dead_code)]
pub(crate) fn patch<F>(&mut self, callback: F) pub(crate) fn patch<F>(&mut self, callback: F)

View File

@ -1,11 +1,13 @@
use bytes::Buf; use bytes::Buf;
use crate::database::{HasArguments, HasValueRef};
use crate::decode::Decode; use crate::decode::Decode;
use crate::encode::{Encode, IsNull}; use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError; use crate::error::BoxDynError;
use crate::postgres::type_info::PgType; use crate::postgres::type_info::PgType;
use crate::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; use crate::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
use crate::types::Type; use crate::types::Type;
use std::iter::FromIterator;
impl<T> Type<Postgres> for [Option<T>] impl<T> Type<Postgres> for [Option<T>]
where where
@ -33,14 +35,28 @@ where
} }
} }
impl<I> Type<Postgres> for crate::types::Array<I>
where
I: IntoIterator,
[I::Item]: Type<Postgres>,
{
fn type_info() -> PgTypeInfo {
<[I::Item] as Type<Postgres>>::type_info()
}
fn compatible(ty: &PgTypeInfo) -> bool {
<[I::Item] as Type<Postgres>>::compatible(ty)
}
}
impl<'q, T> Encode<'q, Postgres> for Vec<T> impl<'q, T> Encode<'q, Postgres> for Vec<T>
where where
for<'a> &'a [T]: Encode<'q, Postgres>, T: Encode<'q, Postgres> + Type<Postgres>,
T: Encode<'q, Postgres>,
{ {
#[inline] #[inline]
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
self.as_slice().encode_by_ref(buf) buf.encode_iter(self.as_slice());
IsNull::No
} }
} }
@ -49,25 +65,18 @@ where
T: Encode<'q, Postgres> + Type<Postgres>, T: Encode<'q, Postgres> + Type<Postgres>,
{ {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
buf.extend(&1_i32.to_be_bytes()); // number of dimensions buf.encode_iter(*self);
buf.extend(&0_i32.to_be_bytes()); // flags IsNull::No
// element type
match T::type_info().0 {
PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
ty => {
buf.extend(&ty.oid().to_be_bytes());
}
}
buf.extend(&(self.len() as i32).to_be_bytes()); // len
buf.extend(&1_i32.to_be_bytes()); // lower bound
for element in self.iter() {
buf.encode(element);
} }
}
impl<'q, T, I> Encode<'q, Postgres> for crate::types::Array<I>
where
for<'a> &'a I: IntoIterator<Item = T>,
T: Encode<'q, Postgres> + Type<Postgres>,
{
fn encode_by_ref(&self, buf: &mut <Postgres as HasArguments<'q>>::ArgumentBuffer) -> IsNull {
buf.encode_iter(&self.0);
IsNull::No IsNull::No
} }
} }
@ -77,6 +86,27 @@ where
T: for<'a> Decode<'a, Postgres> + Type<Postgres>, T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
{ {
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> { fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
// `impl<T> FromIterator<T> for Vec<T>` is specialized for `vec::IntoIter<T>`:
// https://github.com/rust-lang/rust/blob/8a9fa3682dcf0de095ec308a29a7b19b0e011ef0/library/alloc/src/vec/spec_from_iter.rs
decode_array(value)
}
}
impl<'r, I> Decode<'r, Postgres> for crate::types::Array<I>
where
I: IntoIterator + FromIterator<<I as IntoIterator>::Item>,
I::Item: for<'a> Decode<'a, Postgres> + Type<Postgres>,
{
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
decode_array(value).map(Self)
}
}
fn decode_array<T, I>(value: PgValueRef<'_>) -> Result<I, BoxDynError>
where
I: FromIterator<T>,
T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
{
let element_type_info; let element_type_info;
let format = value.format(); let format = value.format();
@ -91,7 +121,7 @@ where
if ndim == 0 { if ndim == 0 {
// zero dimensions is an empty array // zero dimensions is an empty array
return Ok(Vec::new()); return Ok(I::from_iter(std::iter::empty()));
} }
if ndim != 1 { if ndim != 1 {
@ -118,17 +148,9 @@ where
return Err(format!("encountered an array with a lower bound of {} in the first dimension; only arrays starting at one are supported", lower).into()); return Err(format!("encountered an array with a lower bound of {} in the first dimension; only arrays starting at one are supported", lower).into());
} }
let mut elements = Vec::with_capacity(len as usize); (0..len)
.map(|_| T::decode(PgValueRef::get(&mut buf, format, element_type_info.clone())))
for _ in 0..len { .collect()
elements.push(T::decode(PgValueRef::get(
&mut buf,
format,
element_type_info.clone(),
))?)
}
Ok(elements)
} }
PgValueFormat::Text => { PgValueFormat::Text => {
@ -144,7 +166,7 @@ where
if s.is_empty() { if s.is_empty() {
// short-circuit empty arrays up here // short-circuit empty arrays up here
return Ok(Vec::new()); return Ok(I::from_iter(std::iter::empty()));
} }
// NOTE: Nearly *all* types use ',' as the sequence delimiter. Yes, there is one // NOTE: Nearly *all* types use ',' as the sequence delimiter. Yes, there is one
@ -154,17 +176,18 @@ where
// delimiter selection // delimiter selection
let delimiter = ','; let delimiter = ',';
let mut done = false;
let mut in_quotes = false; let mut in_quotes = false;
let mut in_escape = false; let mut in_escape = false;
let mut value = String::with_capacity(10); let mut value = String::with_capacity(10);
let mut chars = s.chars(); let mut chars = s.chars();
let mut elements = Vec::with_capacity(4);
while !done { std::iter::from_fn(|| {
loop { if chars.as_str().is_empty() {
match chars.next() { return None;
Some(ch) => match ch { }
for ch in &mut chars {
match ch {
_ if in_escape => { _ if in_escape => {
value.push(ch); value.push(ch);
in_escape = false; in_escape = false;
@ -185,12 +208,6 @@ where
_ => { _ => {
value.push(ch); value.push(ch);
} }
},
None => {
done = true;
break;
}
} }
} }
@ -200,18 +217,18 @@ where
Some(value.as_bytes()) Some(value.as_bytes())
}; };
elements.push(T::decode(PgValueRef { let ret = T::decode(PgValueRef {
value: value_opt, value: value_opt,
row: None, row: None,
type_info: element_type_info.clone(), type_info: element_type_info.clone(),
format, format,
})?); });
value.clear(); value.clear();
}
Ok(elements) Some(ret)
} })
.collect()
} }
} }
} }

View File

@ -83,6 +83,19 @@ impl<'q, DB: Database> Query<'q, DB, <DB as HasArguments<'q>>::Arguments> {
self self
} }
/// Bind any iterable as an array.
/// Only supported on databases with first-class arrays, like Postgres.
///
/// See also: [Array][crate::types::Array]
#[cfg(feature = "array")]
#[cfg_attr(docsrs, doc(cfg(feature = "array")))]
pub fn bind_array<I>(mut self, iter: I) -> Self
where
crate::types::Array<I>: Encode<'q, DB> + Type<DB> + Send + 'q,
{
self.bind(crate::types::Array(iter))
}
} }
impl<'q, DB, A> Query<'q, DB, A> impl<'q, DB, A> Query<'q, DB, A>

View File

@ -55,6 +55,19 @@ impl<'q, DB: Database, O> QueryAs<'q, DB, O, <DB as HasArguments<'q>>::Arguments
self.inner = self.inner.bind(value); self.inner = self.inner.bind(value);
self self
} }
/// Bind any iterable as an array.
/// Only supported on databases with first-class arrays, like Postgres.
///
/// See also: [Array][crate::types::Array]
#[cfg(feature = "array")]
#[cfg_attr(docsrs, doc(cfg(feature = "array")))]
pub fn bind_array<I>(mut self, iter: I) -> Self
where
crate::types::Array<I>: Encode<'q, DB> + Type<DB> + Send + 'q,
{
self.bind(crate::types::Array(iter))
}
} }
// FIXME: This is very close, nearly 1:1 with `Map` // FIXME: This is very close, nearly 1:1 with `Map`

View File

@ -52,6 +52,19 @@ impl<'q, DB: Database, O> QueryScalar<'q, DB, O, <DB as HasArguments<'q>>::Argum
self.inner = self.inner.bind(value); self.inner = self.inner.bind(value);
self self
} }
/// Bind any iterable as an array.
/// Only supported on databases with first-class arrays, like Postgres.
///
/// See also: [Array][crate::types::Array]
#[cfg(feature = "array")]
#[cfg_attr(docsrs, doc(cfg(feature = "array")))]
pub fn bind_array<I>(mut self, iter: I) -> Self
where
crate::types::Array<I>: Encode<'q, DB> + Type<DB> + Send + 'q,
{
self.bind(crate::types::Array(iter))
}
} }
// FIXME: This is very close, nearly 1:1 with `Map` // FIXME: This is very close, nearly 1:1 with `Map`

View File

@ -0,0 +1,110 @@
use std::ops::{Deref, DerefMut};
/// A generic adapter for encoding and decoding any type that implements
/// [`IntoIterator`][std::iter::IntoIterator]/[`FromIterator`][std::iter::FromIterator]
/// to or from an array in SQL, respectively.
///
/// Only supported on databases that have native support for arrays, such as PostgreSQL.
///
/// ## Examples
///
/// #### (Postgres) Bulk Insert with Array of Structs -> Struct of Arrays
///
/// You can implement bulk insert of structs by turning an array of structs into
/// an array for each field in the struct and then using Postgres' `UNNEST()`
///
/// ```rust,ignore
/// use sqlx::types::Array;
///
/// struct Foo {
/// bar: String,
/// baz: i32,
/// quux: bool
/// }
///
/// let foos = vec![
/// Foo {
/// bar: "bar".to_string(),
/// baz: 0,
/// quux: bool
/// }
/// ];
///
/// sqlx::query!(
/// "
/// INSERT INTO foo(bar, baz, quux)
/// SELECT * FROM UNNEST($1, $2, $3)
/// ",
/// // type overrides are necessary for the macros to accept this instead of `[String]`, etc.
/// Array(foos.iter().map(|foo| &foo.bar)) as _,
/// Array(foos.iter().map(|foo| foo.baz)) as _,
/// Array(foos.iter().map(|foo| foo.quux)) as _
/// )
/// .execute(&pool)
/// .await?;
/// ```
///
/// #### (Postgres) Deserialize a Different Type than `Vec<T>`
///
/// ```sql,ignore
/// CREATE TABLE media(
/// id BIGSERIAL PRIMARY KEY,
/// filename TEXT NOT NULL,
/// tags TEXT[] NOT NULL
/// )
/// ```
///
/// ```rust,ignore
/// use sqlx::types::Array;
///
/// use std::collections::HashSet;
///
/// struct Media {
/// id: i32,
/// filename: String,
/// tags: Array<HashSet<T>>,
/// }
///
/// let media: Vec<Media> = sqlx::query_as!(
/// r#"
/// SELECT id, filename, tags AS "tags: Array<HashSet<_>>"
/// "#
/// )
/// .fetch_all(&pool)
/// .await?;
/// ```
#[derive(Debug)]
pub struct Array<I>(pub I);
impl<I> Array<I> {
pub fn into_inner(self) -> I {
self.0
}
}
impl<I> Deref for Array<I> {
type Target = I;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<I> DerefMut for Array<I> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<I> From<I> for Array<I> {
fn from(iterable: I) -> Self {
Self(iterable)
}
}
// orphan trait impl error
// impl<I> From<Array<I>> for I {
// fn from(array: Array<I>) -> Self {
// array.0
// }
// }

View File

@ -20,6 +20,10 @@
use crate::database::Database; use crate::database::Database;
#[cfg(feature = "array")]
#[cfg_attr(docsrs, doc(cfg(feature = "array")))]
mod array;
#[cfg(feature = "bstr")] #[cfg(feature = "bstr")]
#[cfg_attr(docsrs, doc(cfg(feature = "bstr")))] #[cfg_attr(docsrs, doc(cfg(feature = "bstr")))]
pub mod bstr; pub mod bstr;
@ -75,6 +79,9 @@ pub mod ipnetwork {
pub use ipnetwork::{IpNetwork, Ipv4Network, Ipv6Network}; pub use ipnetwork::{IpNetwork, Ipv4Network, Ipv6Network};
} }
#[cfg(feature = "array")]
pub use array::Array;
#[cfg(feature = "json")] #[cfg(feature = "json")]
pub use json::Json; pub use json::Json;

View File

@ -476,3 +476,20 @@ test_prepared_type!(money<PgMoney>(Postgres, "123.45::money" == PgMoney(12345)))
test_prepared_type!(money_vec<Vec<PgMoney>>(Postgres, test_prepared_type!(money_vec<Vec<PgMoney>>(Postgres,
"array[123.45,420.00,666.66]::money[]" == vec![PgMoney(12345), PgMoney(42000), PgMoney(66666)], "array[123.45,420.00,666.66]::money[]" == vec![PgMoney(12345), PgMoney(42000), PgMoney(66666)],
)); ));
mod array {
use sqlx::types::Array;
use std::collections::HashSet;
macro_rules! set [
($($item:expr),*) => {{
let mut set = HashSet::new();
$(set.insert($item);)*
set
}}
];
test_type!(array_to_hashset<String>(Postgres,
"array['foo', 'bar', 'baz']" == Array(set!["foo", "bar", "baz"])));
}