diff --git a/sqlx-core/src/postgres/types/array.rs b/sqlx-core/src/postgres/types/array.rs new file mode 100644 index 00000000..11f10d5e --- /dev/null +++ b/sqlx-core/src/postgres/types/array.rs @@ -0,0 +1,202 @@ +/// Encoding and decoding of Postgres arrays. Documentation of the byte format can be found [here](https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/include/utils/array.h;h=7f7e744cb12bc872f628f90dad99dfdf074eb314;hb=master#l6) +use crate::decode::Decode; +use crate::decode::DecodeError; +use crate::encode::Encode; +use crate::io::{Buf, BufMut}; +use crate::postgres::database::Postgres; +use crate::types::HasSqlType; +use std::marker::PhantomData; + +impl Encode for [T] +where + T: Encode, + Postgres: HasSqlType, +{ + fn encode(&self, buf: &mut Vec) { + let mut encoder = ArrayEncoder::new(buf); + for item in self { + encoder.push(item); + } + } +} +impl Encode for Vec +where + [T]: Encode, + Postgres: HasSqlType, +{ + fn encode(&self, buf: &mut Vec) { + self.as_slice().encode(buf) + } +} + +impl Decode for Vec +where + T: Decode, + Postgres: HasSqlType, +{ + fn decode(buf: &[u8]) -> Result { + let decoder = ArrayDecoder::::new(buf)?; + decoder.collect() + } +} + +type Order = byteorder::BigEndian; + +struct ArrayDecoder<'a, T> +where + T: Decode, + Postgres: HasSqlType, +{ + left: usize, + did_error: bool, + + buf: &'a [u8], + + phantom: PhantomData, +} + +impl ArrayDecoder<'_, T> +where + T: Decode, + Postgres: HasSqlType, +{ + fn new(mut buf: &[u8]) -> Result, DecodeError> { + let ndim = buf.get_i32::()?; + let dataoffset = buf.get_i32::()?; + let elemtype = buf.get_i32::()?; + + if ndim == 0 { + return Ok(ArrayDecoder { + left: 0, + did_error: false, + buf, + phantom: PhantomData, + }); + } + + assert_eq!(ndim, 1, "only arrays of dimension 1 is supported"); + + let dimensions = buf.get_i32::()?; + let lower_bnds = buf.get_i32::()?; + + assert_eq!(dataoffset, 0, "arrays with [null bitmap] is not supported"); + assert_eq!( + elemtype, + >::type_info().id.0 as i32, + "mismatched array element type" + ); + assert_eq!(lower_bnds, 1); + + Ok(ArrayDecoder { + left: dimensions as usize, + did_error: false, + buf, + + phantom: PhantomData, + }) + } + + /// Decodes the next element without worring how many are left, or if it previously errored + fn decode_next_element(&mut self) -> Result { + let len = self.buf.get_i32::()?; + let bytes = self.buf.get_bytes(len as usize)?; + Decode::decode(bytes) + } +} + +impl Iterator for ArrayDecoder<'_, T> +where + T: Decode, + Postgres: HasSqlType, +{ + type Item = Result; + + fn next(&mut self) -> Option> { + if self.did_error || self.left == 0 { + return None; + } + + self.left -= 1; + + let decoded = self.decode_next_element(); + self.did_error = decoded.is_err(); + Some(decoded) + } +} + +struct ArrayEncoder<'a, T> +where + T: Encode, + Postgres: HasSqlType, +{ + count: usize, + len_start_index: usize, + buf: &'a mut Vec, + + phantom: PhantomData, +} + +impl ArrayEncoder<'_, T> +where + T: Encode, + Postgres: HasSqlType, +{ + fn new(buf: &mut Vec) -> ArrayEncoder { + let ty = >::type_info(); + + // ndim + buf.put_i32::(1); + // dataoffset + buf.put_i32::(0); + // elemtype + buf.put_i32::(ty.id.0 as i32); + let len_start_index = buf.len(); + // dimensions + buf.put_i32::(0); + // lower_bnds + buf.put_i32::(1); + + ArrayEncoder { + count: 0, + len_start_index, + buf, + + phantom: PhantomData, + } + } + fn push(&mut self, item: &T) { + // Allocate space for the length of the encoded elemement up front + let el_len_index = self.buf.len(); + self.buf.put_i32::(0); + + // Allocate the element it self + let el_start = self.buf.len(); + Encode::encode(item, self.buf); + let el_end = self.buf.len(); + + // Now we know the actual length of the encoded element + let el_len = el_end - el_start; + + // And we can now go back and update the length + self.buf[el_len_index..el_start].copy_from_slice(&(el_len as i32).to_be_bytes()); + + self.count += 1; + } + fn update_len(&mut self) { + const I32_SIZE: usize = std::mem::size_of::(); + + let size_bytes = (self.count as i32).to_be_bytes(); + + self.buf[self.len_start_index..self.len_start_index + I32_SIZE] + .copy_from_slice(&size_bytes); + } +} +impl Drop for ArrayEncoder<'_, T> +where + T: Encode, + Postgres: HasSqlType, +{ + fn drop(&mut self) { + self.update_len(); + } +} diff --git a/sqlx-core/src/postgres/types/bool.rs b/sqlx-core/src/postgres/types/bool.rs index 7f850dc1..4deabfb8 100644 --- a/sqlx-core/src/postgres/types/bool.rs +++ b/sqlx-core/src/postgres/types/bool.rs @@ -19,6 +19,11 @@ impl Type for [bool] { PgTypeInfo::new(TypeId::ARRAY_BOOL, "BOOL[]") } } +impl HasSqlType> for Postgres { + fn type_info() -> PgTypeInfo { + >::type_info() + } +} impl Encode for bool { fn encode(&self, buf: &mut Vec) { diff --git a/sqlx-core/src/postgres/types/bytes.rs b/sqlx-core/src/postgres/types/bytes.rs index 5e0a81a6..a2f77fe9 100644 --- a/sqlx-core/src/postgres/types/bytes.rs +++ b/sqlx-core/src/postgres/types/bytes.rs @@ -19,6 +19,12 @@ impl Type for [&'_ [u8]] { } } +impl Type for Vec<&'_ [u8]> { + fn type_info() -> PgTypeInfo { + <&'_ [u8] as Type>::type_info() + } +} + impl Type for Vec { fn type_info() -> PgTypeInfo { <[u8] as Type>::type_info() diff --git a/sqlx-core/src/postgres/types/chrono.rs b/sqlx-core/src/postgres/types/chrono.rs index f058f9f7..ab273e5f 100644 --- a/sqlx-core/src/postgres/types/chrono.rs +++ b/sqlx-core/src/postgres/types/chrono.rs @@ -67,6 +67,33 @@ where } } +impl Type for Vec { + fn type_info() -> PgTypeInfo { + <[NaiveTime] as Type>::type_info() + } +} + +impl Type for Vec { + fn type_info() -> PgTypeInfo { + <[NaiveDate] as Type>::type_info() + } +} + +impl Type for Vec { + fn type_info() -> PgTypeInfo { + <[NaiveDateTime] as Type>::type_info() + } +} + +impl Type for Vec> +where + Tz: TimeZone, +{ + fn type_info() -> PgTypeInfo { + <[NaiveDateTime] as Type>::type_info() + } +} + impl<'de> Decode<'de, Postgres> for NaiveTime { fn decode(value: Option>) -> crate::Result { match value.try_into()? { diff --git a/sqlx-core/src/postgres/types/float.rs b/sqlx-core/src/postgres/types/float.rs index b434525d..51614e39 100644 --- a/sqlx-core/src/postgres/types/float.rs +++ b/sqlx-core/src/postgres/types/float.rs @@ -22,6 +22,11 @@ impl Type for [f32] { PgTypeInfo::new(TypeId::ARRAY_FLOAT4, "FLOAT4[]") } } +impl HasSqlType> for Postgres { + fn type_info() -> PgTypeInfo { + >::type_info() + } +} impl Encode for f32 { fn encode(&self, buf: &mut Vec) { @@ -53,6 +58,11 @@ impl Type for [f64] { PgTypeInfo::new(TypeId::ARRAY_FLOAT8, "FLOAT8[]") } } +impl HasSqlType> for Postgres { + fn type_info() -> PgTypeInfo { + >::type_info() + } +} impl Encode for f64 { fn encode(&self, buf: &mut Vec) { diff --git a/sqlx-core/src/postgres/types/int.rs b/sqlx-core/src/postgres/types/int.rs index f1d1d3ba..96cddc9b 100644 --- a/sqlx-core/src/postgres/types/int.rs +++ b/sqlx-core/src/postgres/types/int.rs @@ -22,6 +22,11 @@ impl Type for [i16] { PgTypeInfo::new(TypeId::ARRAY_INT2, "INT2[]") } } +impl HasSqlType> for Postgres { + fn type_info() -> PgTypeInfo { + >::type_info() + } +} impl Encode for i16 { fn encode(&self, buf: &mut Vec) { @@ -49,6 +54,11 @@ impl Type for [i32] { PgTypeInfo::new(TypeId::ARRAY_INT4, "INT4[]") } } +impl HasSqlType> for Postgres { + fn type_info() -> PgTypeInfo { + >::type_info() + } +} impl Encode for i32 { fn encode(&self, buf: &mut Vec) { @@ -76,6 +86,11 @@ impl Type for [i64] { PgTypeInfo::new(TypeId::ARRAY_INT8, "INT8[]") } } +impl HasSqlType> for Postgres { + fn type_info() -> PgTypeInfo { + >::type_info() + } +} impl Encode for i64 { fn encode(&self, buf: &mut Vec) { diff --git a/sqlx-core/src/postgres/types/mod.rs b/sqlx-core/src/postgres/types/mod.rs index f1f471da..c252fc1b 100644 --- a/sqlx-core/src/postgres/types/mod.rs +++ b/sqlx-core/src/postgres/types/mod.rs @@ -59,6 +59,7 @@ use crate::postgres::protocol::TypeId; use crate::postgres::{PgValue, Postgres}; use crate::types::TypeInfo; +mod array; mod bool; mod bytes; mod float; diff --git a/sqlx-core/src/postgres/types/str.rs b/sqlx-core/src/postgres/types/str.rs index b2f77d65..6012e594 100644 --- a/sqlx-core/src/postgres/types/str.rs +++ b/sqlx-core/src/postgres/types/str.rs @@ -21,12 +21,27 @@ impl Type for [&'_ str] { PgTypeInfo::new(TypeId::ARRAY_TEXT, "TEXT[]") } } +impl HasSqlType> for Postgres { + fn type_info() -> PgTypeInfo { + >::type_info() + } +} impl Type for String { fn type_info() -> PgTypeInfo { >::type_info() } } +impl HasSqlType<[String]> for Postgres { + fn type_info() -> PgTypeInfo { + >::type_info() + } +} +impl HasSqlType> for Postgres { + fn type_info() -> PgTypeInfo { + >>::type_info() + } +} impl Encode for str { fn encode(&self, buf: &mut Vec) { diff --git a/sqlx-core/src/postgres/types/uuid.rs b/sqlx-core/src/postgres/types/uuid.rs index b317968e..a5aae947 100644 --- a/sqlx-core/src/postgres/types/uuid.rs +++ b/sqlx-core/src/postgres/types/uuid.rs @@ -23,6 +23,12 @@ impl Type for [Uuid] { } } +impl HasSqlType> for Postgres { + fn type_info() -> PgTypeInfo { + >::type_info() + } +} + impl Encode for Uuid { fn encode(&self, buf: &mut Vec) { buf.extend_from_slice(self.as_bytes()); diff --git a/sqlx-macros/src/database/postgres.rs b/sqlx-macros/src/database/postgres.rs index 982bebad..3faac941 100644 --- a/sqlx-macros/src/database/postgres.rs +++ b/sqlx-macros/src/database/postgres.rs @@ -30,7 +30,16 @@ impl_database_ext! { sqlx::types::BigDecimal, #[cfg(feature = "ipnetwork")] - sqlx::types::ipnetwork::IpNetwork + sqlx::types::ipnetwork::IpNetwork, + + // Arrays + Vec | &[bool], + Vec | &[String], + Vec | &[i16], + Vec | &[i32], + Vec | &[i64], + Vec | &[f32], + Vec | &[f64], }, ParamChecking::Strong, feature-types: info => info.type_feature_gate(), diff --git a/tests/postgres-macros.rs b/tests/postgres-macros.rs index d814b1c0..bcea189d 100644 --- a/tests/postgres-macros.rs +++ b/tests/postgres-macros.rs @@ -186,6 +186,32 @@ async fn test_many_args() -> anyhow::Result<()> { Ok(()) } +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn test_array_from_slice() -> anyhow::Result<()> { + let mut conn = connect().await?; + + let list: &[i32] = &[1, 2, 3, 4i32]; + + let result = sqlx::query!("SELECT $1::int[] as my_array", *list) + .fetch_one(&mut conn) + .await?; + + assert_eq!(result.my_array, vec![1, 2, 3, 4]); + + println!("result ID: {:?}", result.my_array); + + let account = sqlx::query!("SELECT ARRAY[4,3,2,1] as my_array") + .fetch_one(&mut conn) + .await?; + + assert_eq!(account.my_array, vec![4, 3, 2, 1]); + + println!("account ID: {:?}", account.my_array); + + Ok(()) +} + async fn connect() -> anyhow::Result { let _ = dotenv::dotenv(); let _ = env_logger::try_init(); diff --git a/tests/postgres-types.rs b/tests/postgres-types.rs index 865210d2..59029342 100644 --- a/tests/postgres-types.rs +++ b/tests/postgres-types.rs @@ -282,6 +282,12 @@ async fn test_unprepared_anonymous_record() -> anyhow::Result<()> { Ok(()) } +test!(postgres_int_vec: Vec: "ARRAY[1, 2, 3]::int[]" == vec![1, 2, 3i32], "ARRAY[3, 292, 15, 2, 3]::int[]" == vec![3, 292, 15, 2, 3], "ARRAY[7, 6, 5, 4, 3, 2, 1]::int[]" == vec![7, 6, 5, 4, 3, 2, 1], "ARRAY[]::int[]" == vec![] as Vec); +test!(postgres_string_vec: Vec: "ARRAY['Hello', 'world', 'friend']::text[]" == vec!["Hello", "world", "friend"]); +test!(postgres_bool_vec: Vec: "ARRAY[true, true, false, true]::bool[]" == vec![true, true, false, true]); +test!(postgres_real_vec: Vec: "ARRAY[0.0, 1.0, 3.14, 1.234, -0.002, 100000.0]::real[]" == vec![0.0, 1.0, 3.14, 1.234, -0.002, 100000.0_f32]); +test!(postgres_double_vec: Vec: "ARRAY[0.0, 1.0, 3.14, 1.234, -0.002, 100000.0]::double precision[]" == vec![0.0, 1.0, 3.14, 1.234, -0.002, 100000.0_f64]); + #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn test_prepared_structs() -> anyhow::Result<()> {