Add PgBindIter for encoding and use it as the implementation for encoding &[T] (#3651)

* Add PgBindIter for encoding and use it as the implementation for encoding &[T]

* Implement suggestions from review

* Add docs to PgBindIter and test to ensure it works for owned and borrowed types

* Use extension trait for iterators to allow code to flow better. Make struct private. Don't reference unneeded generic T. Make doc tests compile.

* Fix doc function

* Fix doc test to actually compile

* Use Cell<Option<I>> instead of Clone bound
This commit is contained in:
Tyler Hawkes 2025-07-04 18:59:38 -06:00 committed by GitHub
parent 60f67dbc39
commit 97de03482c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 215 additions and 29 deletions

View File

@ -0,0 +1,154 @@
use crate::{type_info::PgType, PgArgumentBuffer, PgHasArrayType, PgTypeInfo, Postgres};
use core::cell::Cell;
use sqlx_core::{
database::Database,
encode::{Encode, IsNull},
error::BoxDynError,
types::Type,
};
// Not exported, but `pub` because it is the return type of the public
// extension trait `PgBindIterExt::bind_iter`.
//
// The iterator is stored in a `Cell<Option<I>>` so that `Encode::encode_by_ref`
// (which only receives `&self`) can move it out exactly once without requiring
// `I: Clone`.
pub struct PgBindIter<I>(Cell<Option<I>>);
/// Iterator extension trait enabling iterators to encode arrays in Postgres.
///
/// Because of the blanket impl of `PgHasArrayType` for all references,
/// we can borrow instead of needing to clone or copy in the iterators
/// and it still works.
///
/// Previously, 3 separate arrays would be needed in this example, which
/// requires iterating 3 times to collect items into the array and then
/// iterating over them again to encode.
///
/// This now requires only iterating over the array once for each field
/// while using less memory, giving both speed and memory usage improvements,
/// along with allowing much more flexibility in the underlying collection.
///
/// ```rust,no_run
/// # async fn test_bind_iter() -> Result<(), sqlx::error::BoxDynError> {
/// # use sqlx::types::chrono::{DateTime, Utc};
/// # use sqlx::Connection;
/// # fn people() -> &'static [Person] {
/// #     &[]
/// # }
/// # let mut conn = <sqlx::Postgres as sqlx::Database>::Connection::connect("dummyurl").await?;
/// use sqlx::postgres::PgBindIterExt;
///
/// #[derive(sqlx::FromRow)]
/// struct Person {
///     id: i64,
///     name: String,
///     birthdate: DateTime<Utc>,
/// }
///
/// # let people: &[Person] = people();
/// sqlx::query("insert into person(id, name, birthdate) select * from unnest($1, $2, $3)")
///     .bind(people.iter().map(|p| p.id).bind_iter())
///     .bind(people.iter().map(|p| &p.name).bind_iter())
///     .bind(people.iter().map(|p| &p.birthdate).bind_iter())
///     .execute(&mut conn)
///     .await?;
///
/// # Ok(())
/// # }
/// ```
pub trait PgBindIterExt: Iterator + Sized {
    /// Wraps this iterator so it can be passed to `.bind(...)` and encoded
    /// as a one-dimensional Postgres array without collecting into a `Vec` first.
    ///
    /// The returned binder consumes the iterator on first encode and must
    /// therefore only be bound/encoded once.
    fn bind_iter(self) -> PgBindIter<Self>;
}
// Blanket impl: any sized iterator gets `bind_iter()`.
impl<I: Iterator + Sized> PgBindIterExt for I {
    fn bind_iter(self) -> PgBindIter<I> {
        // Stash the iterator in a `Cell<Option<_>>` so it can later be moved
        // out through `&self` in `Encode::encode_by_ref`.
        PgBindIter(Cell::new(Some(self)))
    }
}
// A `PgBindIter` binds as the *array* type of the iterator's element type,
// so it is type-compatible wherever a `&[Item]` / `Vec<Item>` bind would be.
impl<I> Type<Postgres> for PgBindIter<I>
where
    I: Iterator,
    I::Item: Type<Postgres> + PgHasArrayType,
{
    fn type_info() -> <Postgres as Database>::TypeInfo {
        I::Item::array_type_info()
    }

    fn compatible(ty: &PgTypeInfo) -> bool {
        I::Item::array_compatible(ty)
    }
}
impl<'q, I> PgBindIter<I>
where
    I: Iterator,
    <I as Iterator>::Item: Type<Postgres> + Encode<'q, Postgres>,
{
    /// Encodes the iterator's items into `buf` using the Postgres binary
    /// one-dimensional array format: header (dimensions, flags, element OID),
    /// then length + lower bound, then each encoded element.
    ///
    /// The element count is not known up front, so a zero length is written
    /// first and patched in place once iteration finishes. Errors if the
    /// iterator yields more than `i32::MAX` items (Postgres array lengths are
    /// signed 32-bit).
    fn encode_inner(
        // need ownership to iterate
        mut iter: I,
        buf: &mut PgArgumentBuffer,
    ) -> Result<IsNull, BoxDynError> {
        // Captured before iterating, only used to report a size in the
        // too-large error below.
        let lower_size_hint = iter.size_hint().0;
        // Pull the first item early: its concrete type info (if any) decides
        // the element OID written into the header.
        let first = iter.next();
        let type_info = first
            .as_ref()
            .and_then(Encode::produces)
            .unwrap_or_else(<I as Iterator>::Item::type_info);
        buf.extend(&1_i32.to_be_bytes()); // number of dimensions
        buf.extend(&0_i32.to_be_bytes()); // flags
        // Element type: either a resolved OID, or a placeholder patched later
        // once the type name/array is resolved against the server.
        match type_info.0 {
            PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
            PgType::DeclareArrayOf(array) => buf.patch_array_type(array),
            ty => {
                buf.extend(&ty.oid().0.to_be_bytes());
            }
        }
        // Remember where the length lives so we can patch it after counting.
        let len_start = buf.len();
        buf.extend(0_i32.to_be_bytes()); // len (unknown so far)
        buf.extend(1_i32.to_be_bytes()); // lower bound
        match first {
            Some(first) => buf.encode(first)?,
            // Empty iterator: the zero length already written is correct.
            None => return Ok(IsNull::No),
        }
        let mut count = 1_i32;
        // First element already encoded, so at most `i32::MAX - 1` more
        // keeps `count` within `i32::MAX`.
        const MAX: usize = i32::MAX as usize - 1;
        for value in (&mut iter).take(MAX) {
            buf.encode(value)?;
            count += 1;
        }
        // If anything remains after `i32::MAX` elements, the true size is at
        // least `i32::MAX + 1` — report whichever bound is larger.
        const OVERFLOW: usize = i32::MAX as usize + 1;
        if iter.next().is_some() {
            let iter_size = std::cmp::max(lower_size_hint, OVERFLOW);
            return Err(format!("encoded iterator is too large for Postgres: {iter_size}").into());
        }
        // set the length now that we know what it is.
        buf[len_start..(len_start + 4)].copy_from_slice(&count.to_be_bytes());
        Ok(IsNull::No)
    }
}
impl<'q, I> Encode<'q, Postgres> for PgBindIter<I>
where
    I: Iterator,
    <I as Iterator>::Item: Type<Postgres> + Encode<'q, Postgres>,
{
    /// Encodes by moving the iterator out of the interior `Cell`.
    ///
    /// # Panics
    ///
    /// Panics if called more than once: the iterator is consumed on the first
    /// call and the `Cell` is left empty.
    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
        Self::encode_inner(self.0.take().expect("PgBindIter is only used once"), buf)
    }
    /// By-value encode; same single-use contract as `encode_by_ref`.
    fn encode(self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError>
    where
        Self: Sized,
    {
        Self::encode_inner(
            self.0.into_inner().expect("PgBindIter is only used once"),
            buf,
        )
    }
}

View File

@ -7,6 +7,7 @@ use crate::executor::Executor;
mod advisory_lock;
mod arguments;
mod bind_iter;
mod column;
mod connection;
mod copy;
@ -47,6 +48,7 @@ pub(crate) use sqlx_core::driver_prelude::*;
pub use advisory_lock::{PgAdvisoryLock, PgAdvisoryLockGuard, PgAdvisoryLockKey};
pub use arguments::{PgArgumentBuffer, PgArguments};
pub use bind_iter::PgBindIterExt;
pub use column::PgColumn;
pub use connection::PgConnection;
pub use copy::{PgCopyIn, PgPoolCopyExt};

View File

@ -5,7 +5,6 @@ use std::borrow::Cow;
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::type_info::PgType;
use crate::types::Oid;
use crate::types::Type;
use crate::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
@ -156,39 +155,14 @@ where
T: Encode<'q, Postgres> + Type<Postgres>,
{
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
let type_info = self
.first()
.and_then(Encode::produces)
.unwrap_or_else(T::type_info);
buf.extend(&1_i32.to_be_bytes()); // number of dimensions
buf.extend(&0_i32.to_be_bytes()); // flags
// element type
match type_info.0 {
PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
PgType::DeclareArrayOf(array) => buf.patch_array_type(array),
ty => {
buf.extend(&ty.oid().0.to_be_bytes());
}
}
let array_len = i32::try_from(self.len()).map_err(|_| {
// do the length check early to avoid doing unnecessary work
i32::try_from(self.len()).map_err(|_| {
format!(
"encoded array length is too large for Postgres: {}",
self.len()
)
})?;
buf.extend(array_len.to_be_bytes()); // len
buf.extend(&1_i32.to_be_bytes()); // lower bound
for element in self.iter() {
buf.encode(element)?;
}
Ok(IsNull::No)
crate::PgBindIterExt::bind_iter(self.iter()).encode(buf)
}
}

View File

@ -2069,6 +2069,62 @@ async fn test_issue_3052() {
}
#[sqlx_macros::test]
async fn test_bind_iter() -> anyhow::Result<()> {
    use sqlx::postgres::PgBindIterExt;
    use sqlx::types::chrono::{DateTime, Utc};

    // Row type used both to build the inserted data and to read it back.
    #[derive(sqlx::FromRow, PartialEq, Debug)]
    struct Person {
        id: i64,
        name: String,
        birthdate: DateTime<Utc>,
    }

    let mut conn = new::<Postgres>().await?;

    let alice = Person {
        id: 1,
        name: "Alice".into(),
        birthdate: "1984-01-01T00:00:00Z".parse().unwrap(),
    };
    let bob = Person {
        id: 2,
        name: "Bob".into(),
        birthdate: "2000-01-01T00:00:00Z".parse().unwrap(),
    };
    let people = vec![alice, bob];

    sqlx::query(
        r#"
create temporary table person(
id int8 primary key,
name text not null,
birthdate timestamptz not null
)"#,
    )
    .execute(&mut conn)
    .await?;

    // Each column is bound as a single-pass iterator instead of a collected Vec.
    let inserted =
        sqlx::query("insert into person(id, name, birthdate) select * from unnest($1, $2, $3)")
            // owned value
            .bind(people.iter().map(|person| person.id).bind_iter())
            // borrowed value
            .bind(people.iter().map(|person| &person.name).bind_iter())
            .bind(people.iter().map(|person| &person.birthdate).bind_iter())
            .execute(&mut conn)
            .await?
            .rows_affected();
    assert_eq!(inserted, 2);

    // Read everything back and confirm a round trip.
    let fetched = sqlx::query_as::<_, Person>("select * from person order by id")
        .fetch_all(&mut conn)
        .await?;
    assert_eq!(people, fetched);

    Ok(())
}
async fn test_pg_copy_chunked() -> anyhow::Result<()> {
let mut conn = new::<Postgres>().await?;