support arbitrary numbers of bind parameters in query!() et al

This commit is contained in:
Austin Bonander 2020-01-15 00:03:05 -08:00
parent da5c538d22
commit 0fdb875c20
10 changed files with 66 additions and 124 deletions

View File

@ -38,123 +38,22 @@ where
fn into_arguments(self) -> DB::Arguments;
}
impl<DB> IntoArguments<DB> for DB::Arguments
impl<A> IntoArguments<A::Database> for A
where
DB: Database,
A: Arguments,
A::Database: Database<Arguments = Self> + Sized,
{
#[inline]
fn into_arguments(self) -> DB::Arguments {
fn into_arguments(self) -> Self {
self
}
}
#[allow(unused)]
macro_rules! impl_into_arguments {
($B:ident: $( ($idx:tt) -> $T:ident );+;) => {
impl<$($T,)+> crate::arguments::IntoArguments<$B> for ($($T,)+)
where
$($B: crate::types::HasSqlType<$T>,)+
$($T: crate::encode::Encode<$B>,)+
{
fn into_arguments(self) -> <$B as crate::database::Database>::Arguments {
use crate::arguments::Arguments;
#[doc(hidden)]
pub struct ImmutableArguments<DB: Database>(pub DB::Arguments);
let mut arguments = <$B as crate::database::Database>::Arguments::default();
let binds = 0 $(+ { $idx; 1 } )+;
let bytes = 0 $(+ crate::encode::Encode::size_hint(&self.$idx))+;
arguments.reserve(binds, bytes);
$(crate::arguments::Arguments::add(&mut arguments, self.$idx);)+
arguments
}
}
};
}
#[allow(unused)]
macro_rules! impl_into_arguments_for_database {
($B:ident) => {
impl crate::arguments::IntoArguments<$B> for ()
{
#[inline]
fn into_arguments(self) -> <$B as crate::database::Database>::Arguments {
Default::default()
}
}
impl_into_arguments!($B:
(0) -> T1;
);
impl_into_arguments!($B:
(0) -> T1;
(1) -> T2;
);
impl_into_arguments!($B:
(0) -> T1;
(1) -> T2;
(2) -> T3;
);
impl_into_arguments!($B:
(0) -> T1;
(1) -> T2;
(2) -> T3;
(3) -> T4;
);
impl_into_arguments!($B:
(0) -> T1;
(1) -> T2;
(2) -> T3;
(3) -> T4;
(4) -> T5;
);
impl_into_arguments!($B:
(0) -> T1;
(1) -> T2;
(2) -> T3;
(3) -> T4;
(4) -> T5;
(5) -> T6;
);
impl_into_arguments!($B:
(0) -> T1;
(1) -> T2;
(2) -> T3;
(3) -> T4;
(4) -> T5;
(5) -> T6;
(6) -> T7;
);
impl_into_arguments!($B:
(0) -> T1;
(1) -> T2;
(2) -> T3;
(3) -> T4;
(4) -> T5;
(5) -> T6;
(6) -> T7;
(7) -> T8;
);
impl_into_arguments!($B:
(0) -> T1;
(1) -> T2;
(2) -> T3;
(3) -> T4;
(4) -> T5;
(5) -> T6;
(6) -> T7;
(7) -> T8;
(8) -> T9;
);
impl<DB: Database> IntoArguments<DB> for ImmutableArguments<DB> {
fn into_arguments(self) -> <DB as Database>::Arguments {
self.0
}
}

View File

@ -14,5 +14,3 @@ impl Database for MySql {
type TableId = Box<str>;
}
impl_into_arguments_for_database!(MySql);

View File

@ -1,6 +1,6 @@
use byteorder::{ByteOrder, NetworkEndian};
use crate::arguments::Arguments;
use crate::arguments::{Arguments, IntoArguments};
use crate::encode::{Encode, IsNull};
use crate::io::BufMut;
use crate::types::HasSqlType;

View File

@ -14,5 +14,3 @@ impl Database for Postgres {
type TableId = u32;
}
impl_into_arguments_for_database!(Postgres);

View File

@ -1,7 +1,7 @@
use futures_core::Stream;
use futures_util::{future, TryStreamExt};
use crate::arguments::Arguments;
use crate::arguments::{Arguments, ImmutableArguments};
use crate::{
arguments::IntoArguments, database::Database, encode::Encode, executor::Executor, row::FromRow,
types::HasSqlType,
@ -128,13 +128,10 @@ where
// used by query!() and friends
#[doc(hidden)]
pub fn bind_all<I>(self, values: I) -> QueryAs<'q, DB, R, I>
where
I: IntoArguments<DB>,
{
pub fn bind_all(self, values: DB::Arguments) -> QueryAs<'q, DB, R, ImmutableArguments<DB>> {
QueryAs {
query: self.query,
args: values,
args: ImmutableArguments(values),
map_row: self.map_row,
}
}

View File

@ -85,7 +85,7 @@ macro_rules! async_macro (
macro_result(parse_err.to_compile_error())
} else {
let msg = format!("{:?}", e);
macro_result(quote!(compile_error(#msg)))
macro_result(quote!(compile_error!(#msg)))
}
}
}

View File

@ -59,6 +59,7 @@ pub fn quote_args<DB: DatabaseExt>(
let args = input.arg_names.iter();
Ok(quote! {
// emit as a tuple first so each expression is only evaluated once
let args = (#(&$#args),*,);
#args_check
})

View File

@ -55,11 +55,26 @@ where
&columns,
);
let db_path = <C::Database as DatabaseExt>::quotable_path();
let args_count = arg_names.len();
let arg_indices = (0..args_count).map(|i| syn::Index::from(i));
let arg_indices_2 = arg_indices.clone();
Ok(quote! {
macro_rules! macro_result {
(#($#arg_names:expr),*) => {{
use sqlx::arguments::Arguments as _;
#args_tokens
#output.bind_all(args)
let mut query_args = <#db_path as sqlx::Database>::Arguments::default();
query_args.reserve(
#args_count,
0 #(+ sqlx::encode::Encode::<#db_path>::size_hint(args.#arg_indices))*
);
#(query_args.add(args.#arg_indices_2);)*
#output.bind_all(query_args)
}}
}
})

View File

@ -51,10 +51,16 @@ where
let output = output::quote_query_as::<C::Database>(sql, &record_type, &columns);
let arg_names = &input.arg_names;
let args_count = arg_names.len();
let arg_indices = (0..args_count).map(|i| syn::Index::from(i));
let arg_indices_2 = arg_indices.clone();
let db_path = <C::Database as DatabaseExt>::quotable_path();
Ok(quote! {
macro_rules! macro_result {
(#($#arg_names:expr),*) => {{
use sqlx::arguments::Arguments as _;
#[derive(Debug)]
struct #record_type {
#record_fields
@ -62,7 +68,15 @@ where
#args
#output.bind_all(args)
let mut query_args = <#db_path as sqlx::Database>::Arguments::default();
query_args.reserve(
#args_count,
0 #(+ sqlx::encode::Encode::<#db_path>::size_hint(args.#arg_indices))*
);
#(query_args.add(args.#arg_indices_2);)*
#output.bind_all(query_args)
}
}}
})

View File

@ -108,3 +108,23 @@ async fn test_nullable_err() -> sqlx::Result<()> {
panic!("expected `UnexpectedNull`, got {}", err)
}
}
#[async_std::test]
async fn test_many_args() -> sqlx::Result<()> {
let mut conn = sqlx::postgres::connect(&dotenv::var("DATABASE_URL").unwrap()).await?;
// previous implementation would only have supported 10 bind parameters
// (this is really gross to test in MySQL)
let rows = sqlx::query!(
"SELECT * from unnest(array[$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12]::int[]) ids(id);",
0i32, 1i32, 2i32, 3i32, 4i32, 5i32, 6i32, 7i32, 8i32, 9i32, 10i32, 11i32
)
.fetch_all(&mut conn)
.await?;
for (i, row) in rows.iter().enumerate() {
assert_eq!(i as i32, row.id);
}
Ok(())
}