sqlx/sqlx-postgres/src/arguments.rs
Max Bruckner c57b46ceb6
Make Encode return a result (#3126)
* Make encode and encode_by_ref fallible

This only changes the trait for now and makes it compile, calling .expect() on all users. Those will be removed in a later commit.

* PgNumeric: Turn TryFrom Decimal to an infallible From

* Turn panics in Encode implementations into errors

* Add Encode error analogous to the Decode error

* Propagate encode errors through Arguments::add

This pushes the panics one level further up, mostly into bind calls. Those will also be removed later.

* Only check argument encoding at the end

* Use Result in Query internally

* Implement query_with functions in terms of _with_result

* Surface encode errors when executing a query.

* Remove remaining panics in AnyConnectionBackend implementations

* PostgreSQL BigDecimal: Return encode error immediately

* Arguments: Add len method to report how many arguments were added

* Query::bind: Report which argument failed to encode

* IsNull: Add is_null method

* MySqlArguments: Replace manual bitmap code with NullBitMap helper type

* Roll back buffer in MySqlArguments if encoding fails

* Roll back buffer in SqliteArguments if encoding fails

* Roll back PgArgumentBuffer if encoding fails
2024-05-31 12:42:36 -07:00

240 lines
7.1 KiB
Rust

use std::fmt::{self, Write};
use std::ops::{Deref, DerefMut};
use crate::encode::{Encode, IsNull};
use crate::error::Error;
use crate::ext::ustr::UStr;
use crate::types::Type;
use crate::{PgConnection, PgTypeInfo, Postgres};
pub(crate) use sqlx_core::arguments::Arguments;
use sqlx_core::error::BoxDynError;
// TODO: buf.patch(|| ...) is a poor name, can we think of a better name? Maybe `buf.lazy(||)` ?
// TODO: Extend the patch system to support dynamic lengths
// Considerations:
// - The prefixed-len offset needs to be back-tracked and updated
// - message::Bind needs to take a &PgArguments and use a `write` method instead of
// referencing a buffer directly
// - The basic idea is that we write bytes for the buffer until we get somewhere
// that has a patch, we then apply the patch which should write to &mut Vec<u8>,
// backtrack and update the prefixed-len, then write until the next patch offset
/// Buffer holding the binary-encoded bind parameters of a query.
///
/// Every value is stored as a big-endian `i32` length prefix followed by its encoded bytes;
/// a `NULL` value is stored as a length of `-1` with no bytes (see `encode`).
/// The buffer dereferences to the underlying `Vec<u8>`.
#[derive(Default)]
pub struct PgArgumentBuffer {
    /// The encoded bytes of all arguments written so far.
    buffer: Vec<u8>,
    /// Number of arguments that have been successfully added.
    count: usize,
    /// Whenever an `Encode` impl needs to defer some work until after we resolve parameter types
    /// it can use `patch`.
    ///
    /// This is currently only set up to be useful if there is a *fixed-size* slot that needs to be
    /// tweaked from the input type. However, that's the only use case we currently have.
    ///
    /// When the patches are applied, each callback receives the buffer from its offset onwards
    /// and the parameter type reported for its argument index.
    patches: Vec<(
        usize, // offset
        usize, // argument index
        Box<dyn Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync>,
    )>,
    /// Whenever an `Encode` impl encounters a `PgTypeInfo` object that does not have an OID,
    /// it pushes a "hole" that must be patched later.
    ///
    /// The hole is a `usize` offset into the buffer, paired with the type name that should be
    /// resolved. The 4 bytes at that offset are later overwritten with the big-endian OID.
    /// This is done for Records and Arrays as the OID is needed well before we are in an async
    /// function and can just ask postgres.
    type_holes: Vec<(usize, UStr)>, // Vec<{ offset, type_name }>
}
/// Implementation of [`Arguments`] for PostgreSQL.
///
/// Tracks the type of each bound value alongside its binary encoding.
#[derive(Default)]
pub struct PgArguments {
    /// Types of each bind parameter, in the order they were added.
    pub(crate) types: Vec<PgTypeInfo>,
    /// Buffer of encoded bind parameters.
    pub(crate) buffer: PgArgumentBuffer,
}
impl PgArguments {
pub(crate) fn add<'q, T>(&mut self, value: T) -> Result<(), BoxDynError>
where
T: Encode<'q, Postgres> + Type<Postgres>,
{
let type_info = value.produces().unwrap_or_else(T::type_info);
let buffer_snapshot = self.buffer.snapshot();
// encode the value into our buffer
if let Err(error) = self.buffer.encode(value) {
// reset the value buffer to its previous value if encoding failed so we don't leave a half-encoded value behind
self.buffer.reset_to_snapshot(buffer_snapshot);
return Err(error);
};
// remember the type information for this value
self.types.push(type_info);
// increment the number of arguments we are tracking
self.buffer.count += 1;
Ok(())
}
// Apply patches
// This should only go out and ask postgres if we have not seen the type name yet
pub(crate) async fn apply_patches(
&mut self,
conn: &mut PgConnection,
parameters: &[PgTypeInfo],
) -> Result<(), Error> {
let PgArgumentBuffer {
ref patches,
ref type_holes,
ref mut buffer,
..
} = self.buffer;
for (offset, ty, callback) in patches {
let buf = &mut buffer[*offset..];
let ty = &parameters[*ty];
callback(buf, ty);
}
for (offset, name) in type_holes {
let oid = conn.fetch_type_id_by_name(&*name).await?;
buffer[*offset..(*offset + 4)].copy_from_slice(&oid.0.to_be_bytes());
}
Ok(())
}
}
impl<'q> Arguments<'q> for PgArguments {
type Database = Postgres;
fn reserve(&mut self, additional: usize, size: usize) {
self.types.reserve(additional);
self.buffer.reserve(size);
}
fn add<T>(&mut self, value: T) -> Result<(), BoxDynError>
where
T: Encode<'q, Self::Database> + Type<Self::Database>,
{
self.add(value)
}
fn format_placeholder<W: Write>(&self, writer: &mut W) -> fmt::Result {
write!(writer, "${}", self.buffer.count)
}
fn len(&self) -> usize {
self.buffer.count
}
}
impl PgArgumentBuffer {
    /// Writes `value` into the buffer as a length-prefixed binary value.
    ///
    /// Four bytes are reserved for a big-endian `i32` length and the value is encoded after
    /// them. The length is then filled in: the number of bytes written, or `-1` when the value
    /// reports [`IsNull::Yes`].
    ///
    /// # Errors
    ///
    /// Returns the error from the value's `Encode` implementation. Also returns an error if
    /// the encoded value is longer than `i32::MAX` bytes and therefore cannot be represented
    /// in the length prefix. On error, bytes already written by this call remain in the
    /// buffer; callers such as `PgArguments::add` roll back using a snapshot.
    pub(crate) fn encode<'q, T>(&mut self, value: T) -> Result<(), BoxDynError>
    where
        T: Encode<'q, Postgres>,
    {
        // reserve space to write the prefixed length of the value
        let offset = self.len();
        self.extend(&[0; 4]);

        // encode the value into our buffer
        let len = if let IsNull::No = value.encode(self)? {
            let size = self.len() - offset - 4;
            // An `as` cast would silently wrap sizes above `i32::MAX` into a bogus (possibly
            // negative) length and corrupt the message sent to the server.
            i32::try_from(size).map_err(|_| {
                format!(
                    "value size would overflow in the binary protocol encoding: {} > {}",
                    size,
                    i32::MAX
                )
            })?
        } else {
            // Write a -1 to indicate NULL
            // NOTE: It is illegal for [encode] to write any data
            debug_assert_eq!(self.len(), offset + 4);
            -1_i32
        };

        // write the len to the beginning of the value
        self[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes());

        Ok(())
    }

    /// Adds a callback to be invoked later when we know the parameter type.
    ///
    /// The callback is registered at the current end of the buffer, for the index of the
    /// argument currently being encoded (`count` is only incremented after encoding
    /// succeeds). `PgArguments::apply_patches` later invokes it.
    // `dead_code` is allowed because this hook may currently have no callers.
    #[allow(dead_code)]
    pub(crate) fn patch<F>(&mut self, callback: F)
    where
        F: Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync,
    {
        let offset = self.len();
        let index = self.count;

        self.patches.push((offset, index, Box::new(callback)));
    }

    /// Extends the inner buffer by 4 zeroed bytes to hold an OID, and remembers where the OID
    /// goes and which type name it must be resolved from.
    pub(crate) fn patch_type_by_name(&mut self, type_name: &UStr) {
        let offset = self.len();

        self.extend_from_slice(&0_u32.to_be_bytes());
        self.type_holes.push((offset, type_name.clone()));
    }

    /// Captures the current sizes of the buffer, patch list, and type-hole list, plus the
    /// argument count, so that a failed encoding can be undone with `reset_to_snapshot`.
    fn snapshot(&self) -> PgArgumentBufferSnapshot {
        let Self {
            buffer,
            count,
            patches,
            type_holes,
        } = self;

        PgArgumentBufferSnapshot {
            buffer_length: buffer.len(),
            count: *count,
            patches_length: patches.len(),
            type_holes_length: type_holes.len(),
        }
    }

    /// Discards everything written after `snapshot` was taken and restores the argument count.
    fn reset_to_snapshot(
        &mut self,
        PgArgumentBufferSnapshot {
            buffer_length,
            count,
            patches_length,
            type_holes_length,
        }: PgArgumentBufferSnapshot,
    ) {
        self.buffer.truncate(buffer_length);
        self.count = count;
        self.patches.truncate(patches_length);
        self.type_holes.truncate(type_holes_length);
    }
}
/// Saved sizes of a `PgArgumentBuffer`'s contents, used to roll the buffer back after a
/// failed encoding.
struct PgArgumentBufferSnapshot {
    /// Length of the encoded byte buffer.
    buffer_length: usize,
    /// Number of arguments added.
    count: usize,
    /// Number of registered patches.
    patches_length: usize,
    /// Number of registered type holes.
    type_holes_length: usize,
}
impl Deref for PgArgumentBuffer {
    type Target = Vec<u8>;

    /// Gives read access to the raw encoded bytes.
    #[inline]
    fn deref(&self) -> &Vec<u8> {
        let Self { buffer, .. } = self;
        buffer
    }
}
impl DerefMut for PgArgumentBuffer {
    /// Gives mutable access to the raw encoded bytes, e.g. for `extend` calls on the buffer.
    #[inline]
    fn deref_mut(&mut self) -> &mut Vec<u8> {
        let Self { buffer, .. } = self;
        buffer
    }
}