mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-10-02 15:25:32 +00:00
Add PgBindIter for encoding and use it as the implementation encoding &[T] (#3651)
* Add PgBindIter for encoding and use it as the implementation encoding &[T] * Implement suggestions from review * Add docs to PgBindIter and test to ensure it works for owned and borrowed types * Use extension trait for iterators to allow code to flow better. Make struct private. Don't reference unneeded generic T. Make doc tests compile. * Fix doc function * Fix doc test to actually compile * Use Cell<Option<I>> instead of Clone bound
This commit is contained in:
parent
60f67dbc39
commit
97de03482c
154
sqlx-postgres/src/bind_iter.rs
Normal file
154
sqlx-postgres/src/bind_iter.rs
Normal file
@ -0,0 +1,154 @@
|
||||
use crate::{type_info::PgType, PgArgumentBuffer, PgHasArrayType, PgTypeInfo, Postgres};
|
||||
use core::cell::Cell;
|
||||
use sqlx_core::{
|
||||
database::Database,
|
||||
encode::{Encode, IsNull},
|
||||
error::BoxDynError,
|
||||
types::Type,
|
||||
};
|
||||
|
||||
// not exported but pub because it is used in the extension trait
|
||||
pub struct PgBindIter<I>(Cell<Option<I>>);
|
||||
|
||||
/// Iterator extension trait enabling iterators to encode arrays in Postgres.
|
||||
///
|
||||
/// Because of the blanket impl of `PgHasArrayType` for all references
|
||||
/// we can borrow instead of needing to clone or copy in the iterators
|
||||
/// and it still works
|
||||
///
|
||||
/// Previously, 3 separate arrays would be needed in this example which
|
||||
/// requires iterating 3 times to collect items into the array and then
|
||||
/// iterating over them again to encode.
|
||||
///
|
||||
/// This now requires only iterating over the array once for each field
|
||||
/// while using less memory giving both speed and memory usage improvements
|
||||
/// along with allowing much more flexibility in the underlying collection.
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// # async fn test_bind_iter() -> Result<(), sqlx::error::BoxDynError> {
|
||||
/// # use sqlx::types::chrono::{DateTime, Utc};
|
||||
/// # use sqlx::Connection;
|
||||
/// # fn people() -> &'static [Person] {
|
||||
/// # &[]
|
||||
/// # }
|
||||
/// # let mut conn = <sqlx::Postgres as sqlx::Database>::Connection::connect("dummyurl").await?;
|
||||
/// use sqlx::postgres::PgBindIterExt;
|
||||
///
|
||||
/// #[derive(sqlx::FromRow)]
|
||||
/// struct Person {
|
||||
/// id: i64,
|
||||
/// name: String,
|
||||
/// birthdate: DateTime<Utc>,
|
||||
/// }
|
||||
///
|
||||
/// # let people: &[Person] = people();
|
||||
/// sqlx::query("insert into person(id, name, birthdate) select * from unnest($1, $2, $3)")
|
||||
/// .bind(people.iter().map(|p| p.id).bind_iter())
|
||||
/// .bind(people.iter().map(|p| &p.name).bind_iter())
|
||||
/// .bind(people.iter().map(|p| &p.birthdate).bind_iter())
|
||||
/// .execute(&mut conn)
|
||||
/// .await?;
|
||||
///
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub trait PgBindIterExt: Iterator + Sized {
|
||||
fn bind_iter(self) -> PgBindIter<Self>;
|
||||
}
|
||||
|
||||
impl<I: Iterator + Sized> PgBindIterExt for I {
|
||||
fn bind_iter(self) -> PgBindIter<I> {
|
||||
PgBindIter(Cell::new(Some(self)))
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> Type<Postgres> for PgBindIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
<I as Iterator>::Item: Type<Postgres> + PgHasArrayType,
|
||||
{
|
||||
fn type_info() -> <Postgres as Database>::TypeInfo {
|
||||
<I as Iterator>::Item::array_type_info()
|
||||
}
|
||||
fn compatible(ty: &PgTypeInfo) -> bool {
|
||||
<I as Iterator>::Item::array_compatible(ty)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'q, I> PgBindIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
<I as Iterator>::Item: Type<Postgres> + Encode<'q, Postgres>,
|
||||
{
|
||||
fn encode_inner(
|
||||
// need ownership to iterate
|
||||
mut iter: I,
|
||||
buf: &mut PgArgumentBuffer,
|
||||
) -> Result<IsNull, BoxDynError> {
|
||||
let lower_size_hint = iter.size_hint().0;
|
||||
let first = iter.next();
|
||||
let type_info = first
|
||||
.as_ref()
|
||||
.and_then(Encode::produces)
|
||||
.unwrap_or_else(<I as Iterator>::Item::type_info);
|
||||
|
||||
buf.extend(&1_i32.to_be_bytes()); // number of dimensions
|
||||
buf.extend(&0_i32.to_be_bytes()); // flags
|
||||
|
||||
match type_info.0 {
|
||||
PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
|
||||
PgType::DeclareArrayOf(array) => buf.patch_array_type(array),
|
||||
|
||||
ty => {
|
||||
buf.extend(&ty.oid().0.to_be_bytes());
|
||||
}
|
||||
}
|
||||
|
||||
let len_start = buf.len();
|
||||
buf.extend(0_i32.to_be_bytes()); // len (unknown so far)
|
||||
buf.extend(1_i32.to_be_bytes()); // lower bound
|
||||
|
||||
match first {
|
||||
Some(first) => buf.encode(first)?,
|
||||
None => return Ok(IsNull::No),
|
||||
}
|
||||
|
||||
let mut count = 1_i32;
|
||||
const MAX: usize = i32::MAX as usize - 1;
|
||||
|
||||
for value in (&mut iter).take(MAX) {
|
||||
buf.encode(value)?;
|
||||
count += 1;
|
||||
}
|
||||
|
||||
const OVERFLOW: usize = i32::MAX as usize + 1;
|
||||
if iter.next().is_some() {
|
||||
let iter_size = std::cmp::max(lower_size_hint, OVERFLOW);
|
||||
return Err(format!("encoded iterator is too large for Postgres: {iter_size}").into());
|
||||
}
|
||||
|
||||
// set the length now that we know what it is.
|
||||
buf[len_start..(len_start + 4)].copy_from_slice(&count.to_be_bytes());
|
||||
|
||||
Ok(IsNull::No)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'q, I> Encode<'q, Postgres> for PgBindIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
<I as Iterator>::Item: Type<Postgres> + Encode<'q, Postgres>,
|
||||
{
|
||||
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
|
||||
Self::encode_inner(self.0.take().expect("PgBindIter is only used once"), buf)
|
||||
}
|
||||
fn encode(self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
Self::encode_inner(
|
||||
self.0.into_inner().expect("PgBindIter is only used once"),
|
||||
buf,
|
||||
)
|
||||
}
|
||||
}
|
@ -7,6 +7,7 @@ use crate::executor::Executor;
|
||||
|
||||
mod advisory_lock;
|
||||
mod arguments;
|
||||
mod bind_iter;
|
||||
mod column;
|
||||
mod connection;
|
||||
mod copy;
|
||||
@ -47,6 +48,7 @@ pub(crate) use sqlx_core::driver_prelude::*;
|
||||
|
||||
pub use advisory_lock::{PgAdvisoryLock, PgAdvisoryLockGuard, PgAdvisoryLockKey};
|
||||
pub use arguments::{PgArgumentBuffer, PgArguments};
|
||||
pub use bind_iter::PgBindIterExt;
|
||||
pub use column::PgColumn;
|
||||
pub use connection::PgConnection;
|
||||
pub use copy::{PgCopyIn, PgPoolCopyExt};
|
||||
|
@ -5,7 +5,6 @@ use std::borrow::Cow;
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::{Encode, IsNull};
|
||||
use crate::error::BoxDynError;
|
||||
use crate::type_info::PgType;
|
||||
use crate::types::Oid;
|
||||
use crate::types::Type;
|
||||
use crate::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
|
||||
@ -156,39 +155,14 @@ where
|
||||
T: Encode<'q, Postgres> + Type<Postgres>,
|
||||
{
|
||||
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
|
||||
let type_info = self
|
||||
.first()
|
||||
.and_then(Encode::produces)
|
||||
.unwrap_or_else(T::type_info);
|
||||
|
||||
buf.extend(&1_i32.to_be_bytes()); // number of dimensions
|
||||
buf.extend(&0_i32.to_be_bytes()); // flags
|
||||
|
||||
// element type
|
||||
match type_info.0 {
|
||||
PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
|
||||
PgType::DeclareArrayOf(array) => buf.patch_array_type(array),
|
||||
|
||||
ty => {
|
||||
buf.extend(&ty.oid().0.to_be_bytes());
|
||||
}
|
||||
}
|
||||
|
||||
let array_len = i32::try_from(self.len()).map_err(|_| {
|
||||
// do the length check early to avoid doing unnecessary work
|
||||
i32::try_from(self.len()).map_err(|_| {
|
||||
format!(
|
||||
"encoded array length is too large for Postgres: {}",
|
||||
self.len()
|
||||
)
|
||||
})?;
|
||||
|
||||
buf.extend(array_len.to_be_bytes()); // len
|
||||
buf.extend(&1_i32.to_be_bytes()); // lower bound
|
||||
|
||||
for element in self.iter() {
|
||||
buf.encode(element)?;
|
||||
}
|
||||
|
||||
Ok(IsNull::No)
|
||||
crate::PgBindIterExt::bind_iter(self.iter()).encode(buf)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2069,6 +2069,62 @@ async fn test_issue_3052() {
|
||||
}
|
||||
|
||||
#[sqlx_macros::test]
|
||||
async fn test_bind_iter() -> anyhow::Result<()> {
|
||||
use sqlx::postgres::PgBindIterExt;
|
||||
use sqlx::types::chrono::{DateTime, Utc};
|
||||
|
||||
let mut conn = new::<Postgres>().await?;
|
||||
|
||||
#[derive(sqlx::FromRow, PartialEq, Debug)]
|
||||
struct Person {
|
||||
id: i64,
|
||||
name: String,
|
||||
birthdate: DateTime<Utc>,
|
||||
}
|
||||
|
||||
let people: Vec<Person> = vec![
|
||||
Person {
|
||||
id: 1,
|
||||
name: "Alice".into(),
|
||||
birthdate: "1984-01-01T00:00:00Z".parse().unwrap(),
|
||||
},
|
||||
Person {
|
||||
id: 2,
|
||||
name: "Bob".into(),
|
||||
birthdate: "2000-01-01T00:00:00Z".parse().unwrap(),
|
||||
},
|
||||
];
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
create temporary table person(
|
||||
id int8 primary key,
|
||||
name text not null,
|
||||
birthdate timestamptz not null
|
||||
)"#,
|
||||
)
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
|
||||
let rows_affected =
|
||||
sqlx::query("insert into person(id, name, birthdate) select * from unnest($1, $2, $3)")
|
||||
// owned value
|
||||
.bind(people.iter().map(|p| p.id).bind_iter())
|
||||
// borrowed value
|
||||
.bind(people.iter().map(|p| &p.name).bind_iter())
|
||||
.bind(people.iter().map(|p| &p.birthdate).bind_iter())
|
||||
.execute(&mut conn)
|
||||
.await?
|
||||
.rows_affected();
|
||||
assert_eq!(rows_affected, 2);
|
||||
|
||||
let p_query = sqlx::query_as::<_, Person>("select * from person order by id")
|
||||
.fetch_all(&mut conn)
|
||||
.await?;
|
||||
|
||||
assert_eq!(people, p_query);
|
||||
Ok(())
|
||||
}
|
||||
async fn test_pg_copy_chunked() -> anyhow::Result<()> {
|
||||
let mut conn = new::<Postgres>().await?;
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user