// Source: sqlx/sqlx-core/src/arguments.rs (212 lines, 6.8 KiB, Rust)

use std::any;
use either::Either;
use crate::database::HasOutput;
use crate::{encode, Database, Error, Result, TypeEncode, TypeInfo};
/// A collection of arguments to be applied to a prepared statement.
///
/// This container allows for a heterogeneous list of positional and named
/// arguments to be collected before executing the query. Positional and named
/// arguments are stored in two separate lists, each in the order the
/// arguments were added.
///
/// The [`Query`] object uses an internal `Arguments` collection.
pub struct Arguments<'a, Db: Database> {
// arguments bound by name, in insertion order
named: Vec<Argument<'a, Db>>,
// arguments bound by position; an argument's index is its offset in this list
positional: Vec<Argument<'a, Db>>,
}
/// A single argument to be applied to a prepared statement.
///
/// Holds a borrowed value together with the type metadata captured from the
/// value's Rust type when the argument was created.
pub struct Argument<'a, Db: Database> {
// when `true`, `encode` skips the type compatibility check
unchecked: bool,
// `Left(index)` for a positional argument, `Right(name)` for a named one
parameter: Either<usize, &'a str>,
// preserved from `T::type_id()`
type_id: Db::TypeId,
// preserved from `T::compatible`
type_compatible: fn(&Db::TypeInfo) -> bool,
// preserved from `any::type_name::<T>`
// used in error messages
rust_type_name: &'static str,
// TODO: we might want to allow binding to Box<dyn TypeEncode<Db>>
// this would allow an Owned storage of values
// the borrowed value, type-erased so arguments of different types can share a list
value: &'a dyn TypeEncode<Db>,
}
impl<'a, Db: Database> Argument<'a, Db> {
    /// Builds an argument that borrows `value` and is addressed by `parameter`.
    ///
    /// The SQL type id, the compatibility predicate and the Rust type name are
    /// captured from `T` here, because the concrete type is erased once the
    /// value is stored as a trait object. `unchecked` controls whether
    /// `encode` later verifies type compatibility.
    fn new<T: 'a + TypeEncode<Db>>(
        parameter: Either<usize, &'a str>,
        value: &'a T,
        unchecked: bool,
    ) -> Self {
        // capture everything that depends on `T` before the type is erased
        let type_id = T::type_id();
        let rust_type_name = any::type_name::<T>();

        Self {
            unchecked,
            parameter,
            type_id,
            type_compatible: T::compatible,
            rust_type_name,
            value,
        }
    }
}
impl<Db: Database> Default for Arguments<'_, Db> {
    /// Returns a collection holding no named and no positional arguments.
    fn default() -> Self {
        let named = Vec::new();
        let positional = Vec::new();

        Self { named, positional }
    }
}
impl<'a, Db: Database> Arguments<'a, Db> {
    /// Creates an empty `Arguments`.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Appends a positional value to the arguments collection.
    ///
    /// The value is type-checked when it is applied to a prepared statement:
    /// its Rust type must be compatible with the type the database expects.
    /// As an example, given a SQL expression such as
    /// `SELECT * FROM table WHERE field = {}`, if `field` is an integer type
    /// and you attempt to bind a `&str` in Rust, an incompatible type error
    /// will be raised.
    pub fn add<T: 'a + TypeEncode<Db>>(&mut self, value: &'a T) {
        // the next free position is the current number of positional arguments
        let slot = Either::Left(self.positional.len());
        let argument = Argument::new(slot, value, false);
        self.positional.push(argument);
    }

    /// Appends an unchecked positional value to the arguments collection.
    ///
    /// The value is not type-checked against the type the database expects
    /// when it is applied to a prepared statement. Further, in PostgreSQL, the
    /// argument type will not be hinted when preparing the statement.
    pub fn add_unchecked<T: 'a + TypeEncode<Db>>(&mut self, value: &'a T) {
        let slot = Either::Left(self.positional.len());
        let argument = Argument::new(slot, value, true);
        self.positional.push(argument);
    }

    /// Adds a value bound to the parameter called `name`.
    pub fn add_as<T: 'a + TypeEncode<Db>>(&mut self, name: &'a str, value: &'a T) {
        let argument = Argument::new(Either::Right(name), value, false);
        self.named.push(argument);
    }

    /// Adds an unchecked value bound to the parameter called `name`.
    pub fn add_unchecked_as<T: 'a + TypeEncode<Db>>(&mut self, name: &'a str, value: &'a T) {
        let argument = Argument::new(Either::Right(name), value, true);
        self.named.push(argument);
    }
}
impl<'a, Db: Database> Arguments<'a, Db> {
    /// Reserves capacity for at least `additional` more positional parameters.
    pub fn reserve_positional(&mut self, additional: usize) {
        self.positional.reserve(additional);
    }

    /// Reserves capacity for at least `additional` more named parameters.
    pub fn reserve_named(&mut self, additional: usize) {
        self.named.reserve(additional);
    }

    /// Returns the number of positional and named parameters.
    #[must_use]
    pub fn len(&self) -> usize {
        self.num_named() + self.num_positional()
    }

    /// Returns `true` if there are no positional or named parameters.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Clears the `Arguments`, removing all values.
    pub fn clear(&mut self) {
        self.named.clear();
        self.positional.clear();
    }

    /// Returns the number of named parameters.
    #[must_use]
    pub fn num_named(&self) -> usize {
        self.named.len()
    }

    /// Returns the number of positional parameters.
    #[must_use]
    pub fn num_positional(&self) -> usize {
        self.positional.len()
    }

    /// Returns an iterator of the positional parameters.
    pub fn positional(&self) -> impl Iterator<Item = &Argument<'a, Db>> {
        self.positional.iter()
    }

    /// Returns a reference to the argument at the location, if present.
    ///
    /// `index` is either a `&usize` (positional lookup) or a `&str` (lookup by
    /// parameter name). Returns `None` when no argument matches.
    ///
    /// The index type is `?Sized` so that the unsized `str` implementation of
    /// [`ArgumentIndex`] can be used directly, as in `arguments.get("name")`.
    /// With the implicit `Sized` bound that call is rejected by the compiler.
    pub fn get<'x, I>(&'x self, index: &I) -> Option<&'x Argument<'a, Db>>
    where
        I: ?Sized + ArgumentIndex<'a, Db>,
    {
        index.get(self)
    }
}
impl<'a, Db: Database> Argument<'a, Db> {
/// Returns the SQL type identifier of the argument.
#[must_use]
pub fn type_id(&self) -> Db::TypeId {
self.type_id
}
/// Encode this argument into the output buffer, for use in executing
/// the prepared statement.
///
/// When the statement is prepared, the database will often infer the type
/// of the incoming argument. This method takes that (`ty`) along with the value of
/// the argument to encode into the output buffer.
///
pub fn encode<'x>(
&self,
ty: &Db::TypeInfo,
out: &mut <Db as HasOutput<'x>>::Output,
) -> Result<()> {
let res = if !self.unchecked && !(self.type_compatible)(ty) {
Err(encode::Error::TypeNotCompatible {
rust_type_name: self.rust_type_name,
sql_type_name: ty.name(),
})
} else {
self.value.encode(ty, out)
};
res.map_err(|source| Error::ParameterEncode {
parameter: self.parameter.map_right(|name| name.to_string().into_boxed_str()),
source,
})
}
}
/// A helper trait used for indexing into an [`Arguments`] collection.
///
/// Implemented in this file for `str`, which looks up a named argument, and
/// for `usize`, which looks up a positional argument.
pub trait ArgumentIndex<'a, Db: Database> {
/// Returns a reference to the argument at this location, if present.
///
/// Returns `None` when `arguments` holds no argument for this index.
fn get<'x>(&self, arguments: &'x Arguments<'a, Db>) -> Option<&'x Argument<'a, Db>>;
}
// access a named argument by name
impl<'a, Db: Database> ArgumentIndex<'a, Db> for str {
fn get<'x>(&self, arguments: &'x Arguments<'a, Db>) -> Option<&'x Argument<'a, Db>> {
arguments.named.iter().find_map(|arg| (arg.parameter.right() == Some(self)).then(|| arg))
}
}
// look up a positional argument by its zero-based index
impl<'a, Db: Database> ArgumentIndex<'a, Db> for usize {
    /// Returns the positional argument stored at index `self`, or `None` if
    /// the index is out of range.
    fn get<'x>(&self, arguments: &'x Arguments<'a, Db>) -> Option<&'x Argument<'a, Db>> {
        let position = *self;
        arguments.positional.as_slice().get(position)
    }
}