make query!() output anonymous records

This commit is contained in:
Austin Bonander 2019-12-03 15:31:49 -08:00
parent 871183d23b
commit acca40c88e
14 changed files with 301 additions and 219 deletions

View File

@ -32,6 +32,7 @@ proc-macro-hack = { version = "0.5.11", optional = true }
[dev-dependencies]
async-std = { version = "1.2.0", features = [ "attributes" ] }
dotenv = "0.15.0"
matches = "0.1.8"
criterion = "0.3.0"

View File

@ -52,26 +52,26 @@ where
self.live.execute(query, params)
}
fn fetch<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
fn fetch<'e, 'q: 'e, I: 'e, T: 'e>(
&'e mut self,
query: &'q str,
params: I,
) -> BoxStream<'e, crate::Result<T>>
where
I: IntoQueryParameters<Self::Backend> + Send,
T: FromRow<Self::Backend, O> + Send + Unpin,
T: FromRow<Self::Backend> + Send + Unpin,
{
self.live.fetch(query, params)
}
fn fetch_optional<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
fn fetch_optional<'e, 'q: 'e, I: 'e, T: 'e>(
&'e mut self,
query: &'q str,
params: I,
) -> BoxFuture<'e, crate::Result<Option<T>>>
where
I: IntoQueryParameters<Self::Backend> + Send,
T: FromRow<Self::Backend, O> + Send,
T: FromRow<Self::Backend> + Send,
{
self.live.fetch_optional(query, params)
}

View File

@ -30,44 +30,44 @@ pub trait Executor: Send {
where
I: IntoQueryParameters<Self::Backend> + Send;
fn fetch<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
fn fetch<'e, 'q: 'e, I: 'e, T: 'e>(
&'e mut self,
query: &'q str,
params: I,
) -> BoxStream<'e, crate::Result<T>>
where
I: IntoQueryParameters<Self::Backend> + Send,
T: FromRow<Self::Backend, O> + Send + Unpin;
T: FromRow<Self::Backend> + Send + Unpin;
fn fetch_all<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
fn fetch_all<'e, 'q: 'e, I: 'e, T: 'e>(
&'e mut self,
query: &'q str,
params: I,
) -> BoxFuture<'e, crate::Result<Vec<T>>>
where
I: IntoQueryParameters<Self::Backend> + Send,
T: FromRow<Self::Backend, O> + Send + Unpin,
T: FromRow<Self::Backend> + Send + Unpin,
{
Box::pin(self.fetch(query, params).try_collect())
}
fn fetch_optional<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
fn fetch_optional<'e, 'q: 'e, I: 'e, T: 'e>(
&'e mut self,
query: &'q str,
params: I,
) -> BoxFuture<'e, crate::Result<Option<T>>>
where
I: IntoQueryParameters<Self::Backend> + Send,
T: FromRow<Self::Backend, O> + Send;
T: FromRow<Self::Backend> + Send;
fn fetch_one<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
fn fetch_one<'e, 'q: 'e, I: 'e, T: 'e>(
&'e mut self,
query: &'q str,
params: I,
) -> BoxFuture<'e, crate::Result<T>>
where
I: IntoQueryParameters<Self::Backend> + Send,
T: FromRow<Self::Backend, O> + Send,
T: FromRow<Self::Backend> + Send,
{
let fut = self.fetch_optional(query, params);
Box::pin(async move { fut.await?.ok_or(Error::NotFound) })

View File

@ -50,6 +50,9 @@ pub use self::{
#[doc(hidden)]
pub use types::HasTypeMetadata;
#[doc(hidden)]
pub use describe::{Describe, ResultField};
#[cfg(feature = "mariadb")]
pub mod mariadb;

View File

@ -22,14 +22,14 @@ where
Box::pin(async move { <&Pool<DB> as Executor>::execute(&mut &*self, query, params).await })
}
fn fetch<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
fn fetch<'e, 'q: 'e, I: 'e, T: 'e>(
&'e mut self,
query: &'q str,
params: I,
) -> BoxStream<'e, crate::Result<T>>
where
I: IntoQueryParameters<Self::Backend> + Send,
T: FromRow<Self::Backend, O> + Send + Unpin,
T: FromRow<Self::Backend> + Send + Unpin,
{
Box::pin(async_stream::try_stream! {
let mut self_ = &*self;
@ -41,14 +41,14 @@ where
})
}
fn fetch_optional<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
fn fetch_optional<'e, 'q: 'e, I: 'e, T: 'e>(
&'e mut self,
query: &'q str,
params: I,
) -> BoxFuture<'e, crate::Result<Option<T>>>
where
I: IntoQueryParameters<Self::Backend> + Send,
T: FromRow<Self::Backend, O> + Send,
T: FromRow<Self::Backend> + Send,
{
Box::pin(async move {
<&Pool<DB> as Executor>::fetch_optional(&mut &*self, query, params).await
@ -80,14 +80,14 @@ where
Box::pin(async move { self.0.acquire().await?.execute(query, params).await })
}
fn fetch<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
fn fetch<'e, 'q: 'e, I: 'e, T: 'e>(
&'e mut self,
query: &'q str,
params: I,
) -> BoxStream<'e, crate::Result<T>>
where
I: IntoQueryParameters<Self::Backend> + Send,
T: FromRow<Self::Backend, O> + Send + Unpin,
T: FromRow<Self::Backend> + Send + Unpin,
{
Box::pin(async_stream::try_stream! {
let mut live = self.0.acquire().await?;
@ -99,14 +99,14 @@ where
})
}
fn fetch_optional<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
fn fetch_optional<'e, 'q: 'e, I: 'e, T: 'e>(
&'e mut self,
query: &'q str,
params: I,
) -> BoxFuture<'e, crate::Result<Option<T>>>
where
I: IntoQueryParameters<Self::Backend> + Send,
T: FromRow<Self::Backend, O> + Send,
T: FromRow<Self::Backend> + Send,
{
Box::pin(async move { self.0.acquire().await?.fetch_optional(query, params).await })
}

View File

@ -39,5 +39,5 @@ impl Backend for Postgres {
}
}
impl_from_row_for_backend!(Postgres);
impl_from_row_for_backend!(Postgres, DataRow);
impl_into_query_parameters_for_backend!(Postgres);

View File

@ -40,14 +40,14 @@ impl Executor for Postgres {
})
}
fn fetch<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
fn fetch<'e, 'q: 'e, I: 'e, T: 'e>(
&'e mut self,
query: &'q str,
params: I,
) -> BoxStream<'e, crate::Result<T>>
where
I: IntoQueryParameters<Self::Backend> + Send,
T: FromRow<Self::Backend, O> + Send + Unpin,
T: FromRow<Self::Backend> + Send + Unpin,
{
let params = params.into_params();
@ -66,14 +66,14 @@ impl Executor for Postgres {
})
}
fn fetch_optional<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
fn fetch_optional<'e, 'q: 'e, I: 'e, T: 'e>(
&'e mut self,
query: &'q str,
params: I,
) -> BoxFuture<'e, crate::Result<Option<T>>>
where
I: IntoQueryParameters<Self::Backend> + Send,
T: FromRow<Self::Backend, O> + Send,
T: FromRow<Self::Backend> + Send,
{
Box::pin(async move {
let params = params.into_params();
@ -137,7 +137,7 @@ impl Executor for Postgres {
.into_vec()
.into_iter()
.map(|field| ResultField {
name: Some(field.name),
name: if field.name == "?column?" { None } else { Some(field.name) },
table_id: Some(field.table_id),
type_id: field.type_id,
})

View File

@ -33,7 +33,7 @@ where
DB: Backend,
DB::QueryParameters: 'q,
I: IntoQueryParameters<DB> + Send,
O: FromRow<DB, O> + Send + Unpin,
O: FromRow<DB> + Send + Unpin,
{
#[inline]
pub fn execute<E>(self, executor: &'q mut E) -> BoxFuture<'q, crate::Result<u64>>
@ -72,7 +72,7 @@ where
}
}
impl<DB> Query<'_, DB, <DB as Backend>::QueryParameters>
impl<DB> Query<'_, DB>
where
DB: Backend,
{
@ -93,6 +93,17 @@ where
}
}
impl<'q, DB, I, O> Query<'q, DB, I, O> where DB: Backend {
pub fn with_output_type<O_>(self) -> Query<'q, DB, I, O_> {
Query {
query: self.query,
input: self.input,
output: PhantomData,
backend: PhantomData,
}
}
}
/// Construct a full SQL query using raw SQL.
#[inline]
pub fn query<DB>(query: &str) -> Query<'_, DB>

View File

@ -16,7 +16,7 @@ pub trait Row: Send {
}
}
pub trait FromRow<DB: Backend, O = <DB as Backend>::Row> {
pub trait FromRow<DB: Backend> {
fn from_row(row: <DB as Backend>::Row) -> Self;
}
@ -36,27 +36,13 @@ macro_rules! impl_from_row {
($(row.get($idx),)+)
}
}
// (T1, T2, ...) -> (T1, T2, ...)
impl<$($T,)+> crate::row::FromRow<$B, ($($T,)+)> for ($($T,)+)
where
$($B: crate::types::HasSqlType<$T>,)+
$($T: crate::decode::Decode<$B>,)+
{
#[inline]
fn from_row(row: <$B as crate::backend::Backend>::Row) -> Self {
use crate::row::Row;
($(row.get($idx),)+)
}
}
};
}
#[allow(unused)]
macro_rules! impl_from_row_for_backend {
($B:ident) => {
impl crate::row::FromRow<$B> for <$B as crate::backend::Backend>::Row where $B: crate::Backend {
($B:ident, $row:ident) => {
impl crate::row::FromRow<$B> for $row where $B: crate::Backend {
#[inline]
fn from_row(row: <$B as crate::backend::Backend>::Row) -> Self {
row

View File

@ -1,6 +1,8 @@
#[cfg(feature = "uuid")]
pub use uuid::Uuid;
use std::fmt::Display;
/// Information about how a backend stores metadata about
/// given SQL types.
pub trait HasTypeMetadata {
@ -8,7 +10,7 @@ pub trait HasTypeMetadata {
type TypeMetadata: TypeMetadata<Self::TypeId>;
/// The Rust type of type identifiers in `DESCRIBE` responses for the SQL backend.
type TypeId: Eq;
type TypeId: Eq + Display;
}
pub trait TypeMetadata<TypeId: Eq> {

View File

@ -3,6 +3,10 @@ use sqlx::Backend;
pub trait BackendExt: Backend {
const BACKEND_PATH: &'static str;
fn quotable_path() -> syn::Path {
syn::parse_str(Self::BACKEND_PATH).unwrap()
}
fn param_type_for_id(id: &Self::TypeId) -> Option<&'static str>;
fn return_type_for_id(id: &Self::TypeId) -> Option<&'static str>;

View File

@ -1,26 +1,19 @@
#![cfg_attr(not(any(feature = "postgres", feature = "mariadb")), allow(dead_code, unused_macros, unused_imports))]
extern crate proc_macro;
use proc_macro::TokenStream;
use proc_macro2::Span;
use proc_macro_hack::proc_macro_hack;
use quote::{quote, quote_spanned, ToTokens};
use quote::{quote};
use syn::{
parse::{self, Parse, ParseStream},
parse,
parse_macro_input,
punctuated::Punctuated,
spanned::Spanned,
Expr, ExprLit, Lit, Token,
};
use sqlx::{Executor, HasTypeMetadata};
use async_std::task;
use std::fmt::Display;
use url::Url;
type Error = Box<dyn std::error::Error>;
@ -28,46 +21,55 @@ type Result<T> = std::result::Result<T, Error>;
mod backend;
use backend::BackendExt;
mod query;
struct MacroInput {
sql: String,
sql_span: Span,
args: Vec<Expr>,
}
macro_rules! with_database(
($db:ident => $expr:expr) => {
async {
let db_url = Url::parse(&dotenv::var("DATABASE_URL").map_err(|_| "DATABASE_URL not set")?)?;
impl Parse for MacroInput {
fn parse(input: ParseStream) -> parse::Result<Self> {
let mut args = Punctuated::<Expr, Token![,]>::parse_terminated(input)?.into_iter();
match db_url.scheme() {
#[cfg(feature = "postgres")]
"postgresql" | "postgres" => {
let $db = sqlx::Connection::<sqlx::Postgres>::open(db_url.as_str())
.await
.map_err(|e| format!("failed to connect to database: {}", e))?;
let sql = match args.next() {
Some(Expr::Lit(ExprLit {
lit: Lit::Str(sql), ..
})) => sql,
Some(other_expr) => {
return Err(parse::Error::new_spanned(
other_expr,
"expected string literal",
));
$expr.await
}
#[cfg(not(feature = "postgres"))]
"postgresql" | "postgres" => Err(format!(
"DATABASE_URL {} has the scheme of a Postgres database but the `postgres` \
feature of sqlx was not enabled",
db_url
).into()),
#[cfg(feature = "mariadb")]
"mysql" | "mariadb" => {
let $db = sqlx::Connection::<sqlx::MariaDb>::open(db_url.as_str())
.await
.map_err(|e| format!("failed to connect to database: {}", e))?;
$expr.await
}
#[cfg(not(feature = "mariadb"))]
"mysql" | "mariadb" => Err(format!(
"DATABASE_URL {} has the scheme of a MySQL/MariaDB database but the `mariadb` \
feature of sqlx was not enabled",
db_url
).into()),
scheme => Err(format!("unexpected scheme {:?} in DATABASE_URL {}", scheme, db_url).into()),
}
None => return Err(input.error("expected SQL string literal")),
};
Ok(MacroInput {
sql: sql.value(),
sql_span: sql.span(),
args: args.collect(),
})
}
}
}
);
#[proc_macro_hack]
pub fn query(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as MacroInput);
let input = parse_macro_input!(input as query::MacroInput);
match task::block_on(process_sql(input)) {
Ok(ts) => ts,
Err(e) => {
match task::block_on(with_database!(db => query::process_sql(input, db))) {
Ok(ts) => ts.into(),
Result::Err(e) => {
if let Some(parse_err) = e.downcast_ref::<parse::Error>() {
return parse_err.to_compile_error().into();
}
@ -77,131 +79,3 @@ pub fn query(input: TokenStream) -> TokenStream {
}
}
}
async fn process_sql(input: MacroInput) -> Result<TokenStream> {
let db_url = Url::parse(&dotenv::var("DATABASE_URL")?)?;
match db_url.scheme() {
#[cfg(feature = "postgres")]
"postgresql" | "postgres" => {
process_sql_with(
input,
sqlx::Connection::<sqlx::Postgres>::open(db_url.as_str())
.await
.map_err(|e| format!("failed to connect to database: {}", e))?,
)
.await
}
#[cfg(feature = "mariadb")]
"mysql" | "mariadb" => {
process_sql_with(
input,
sqlx::Connection::<sqlx::MariaDb>::open(db_url.as_str())
.await
.map_err(|e| format!("failed to connect to database: {}", e))?,
)
.await
}
scheme => Err(format!("unexpected scheme {:?} in DB_URL {}", scheme, db_url).into()),
}
}
async fn process_sql_with<DB: BackendExt>(
input: MacroInput,
mut conn: sqlx::Connection<DB>,
) -> Result<TokenStream>
where
<DB as HasTypeMetadata>::TypeId: Display,
{
let prepared = conn
.describe(&input.sql)
.await
.map_err(|e| parse::Error::new(input.sql_span, e))?;
if input.args.len() != prepared.param_types.len() {
return Err(parse::Error::new(
Span::call_site(),
format!(
"expected {} parameters, got {}",
prepared.param_types.len(),
input.args.len()
),
)
.into());
}
let param_types = prepared
.param_types
.iter()
.zip(&*input.args)
.map(|(type_, expr)| {
get_type_override(expr)
.or_else(|| {
Some(
<DB as BackendExt>::param_type_for_id(type_)?
.parse::<proc_macro2::TokenStream>()
.unwrap(),
)
})
.ok_or_else(|| format!("unknown type param ID: {}", type_).into())
})
.collect::<Result<Vec<_>>>()?;
let output_types = prepared
.result_fields
.iter()
.map(|column| {
Ok(<DB as BackendExt>::return_type_for_id(&column.type_id)
.ok_or_else(|| format!("unknown field type ID: {}", &column.type_id))?
.parse::<proc_macro2::TokenStream>()
.unwrap())
})
.collect::<Result<Vec<_>>>()?;
let params_ty_cons = input.args.iter().enumerate().map(|(i, expr)| {
// required or `quote!()` emits it as `Nusize`
let i = syn::Index::from(i);
quote_spanned!( expr.span() => { use sqlx::TyConsExt as _; (sqlx::TyCons::new(&params.#i)).ty_cons() })
});
let query = &input.sql;
let backend_path = syn::parse_str::<syn::Path>(DB::BACKEND_PATH).unwrap();
let params = if input.args.is_empty() {
quote! {
let params = ();
}
} else {
let params = input.args.iter();
quote! {
let params = (#(#params),*,);
if false {
use sqlx::TyConsExt as _;
let _: (#(#param_types),*,) = (#(#params_ty_cons),*,);
}
}
};
Ok(quote! {{
#params
sqlx::Query::<#backend_path, _, (#(#output_types),*,)> {
query: #query,
input: params,
output: ::core::marker::PhantomData,
backend: ::core::marker::PhantomData,
}
}}
.into())
}
fn get_type_override(expr: &Expr) -> Option<proc_macro2::TokenStream> {
match expr {
Expr::Cast(cast) => Some(cast.ty.to_token_stream()),
Expr::Type(ascription) => Some(ascription.ty.to_token_stream()),
_ => None,
}
}

200
sqlx-macros/src/query.rs Normal file
View File

@ -0,0 +1,200 @@
use proc_macro2::Span;
use proc_macro2::TokenStream;
use syn::{
parse::{self, Parse, ParseStream},
punctuated::Punctuated,
spanned::Spanned,
Expr, ExprLit, Lit, Token, Ident,
};
use crate::backend::BackendExt;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use sqlx::{Executor, HasTypeMetadata};
use std::fmt::Display;
pub struct MacroInput {
sql: String,
sql_span: Span,
args: Vec<Expr>,
}
impl Parse for MacroInput {
fn parse(input: ParseStream) -> parse::Result<Self> {
let mut args = Punctuated::<Expr, Token![,]>::parse_terminated(input)?.into_iter();
let sql = match args.next() {
Some(Expr::Lit(ExprLit {
lit: Lit::Str(sql), ..
})) => sql,
Some(other_expr) => {
return Err(parse::Error::new_spanned(
other_expr,
"expected string literal",
));
}
None => return Err(input.error("expected SQL string literal")),
};
Ok(MacroInput {
sql: sql.value(),
sql_span: sql.span(),
args: args.collect(),
})
}
}
/// Given an input like `query!("SELECT * FROM accounts WHERE account_id > ?", account_id)`
pub async fn process_sql<DB: BackendExt>(
input: MacroInput,
mut conn: sqlx::Connection<DB>,
) -> crate::Result<TokenStream>
where
<DB as HasTypeMetadata>::TypeId: Display,
{
let describe = conn
.describe(&input.sql)
.await
.map_err(|e| parse::Error::new(input.sql_span, e))?;
if input.args.len() != describe.param_types.len() {
return Err(parse::Error::new(
Span::call_site(),
format!(
"expected {} parameters, got {}",
describe.param_types.len(),
input.args.len()
),
)
.into());
}
let param_types = describe
.param_types
.iter()
.zip(&*input.args)
.map(|(type_, expr)| {
get_type_override(expr)
.or_else(|| {
Some(
<DB as BackendExt>::param_type_for_id(type_)?
.parse::<proc_macro2::TokenStream>()
.unwrap(),
)
})
.ok_or_else(|| format!("unknown type param ID: {}", type_).into())
})
.collect::<crate::Result<Vec<_>>>()?;
let params_ty_cons = input.args.iter().enumerate().map(|(i, expr)| {
// required or `quote!()` emits it as `Nusize`
let i = syn::Index::from(i);
quote_spanned!( expr.span() => { use sqlx::TyConsExt as _; (sqlx::TyCons::new(&params.#i)).ty_cons() })
});
let query = &input.sql;
let backend_path = DB::quotable_path();
// record_type will be wrapped in parens which the compiler ignores without a trailing comma
// e.g. (Foo) == Foo but (Foo,) = one-element tuple
// and giving an empty stream for record_type makes it unit `()`
let (record_type, record) = if describe.result_fields.is_empty() {
(TokenStream::new(), TokenStream::new())
} else {
let record_type = Ident::new("Record", Span::call_site());
(record_type.to_token_stream(), generate_record_def(&describe, &record_type)?)
};
let params = if input.args.is_empty() {
quote! {
let params = ();
}
} else {
let params = input.args.iter();
quote! {
let params = (#(#params),*,);
if false {
use sqlx::TyConsExt as _;
let _: (#(#param_types),*,) = (#(#params_ty_cons),*,);
}
}
};
Ok(quote! {{
#record
#params
sqlx::Query::<#backend_path, _, (#record_type)> {
query: #query,
input: params,
output: ::core::marker::PhantomData,
backend: ::core::marker::PhantomData,
}
}})
}
fn generate_record_def<DB: BackendExt>(describe: &sqlx::Describe<DB>, type_name: &Ident) -> crate::Result<TokenStream> {
let fields = describe.result_fields.iter().enumerate()
.map(|(i, column)| {
let name = column.name.as_deref()
.ok_or_else(|| format!("column at position {} must have a name", i))?;
let name = syn::parse_str::<Ident>(name)
.map_err(|_| format!("{:?} is not a valid Rust identifier", name))?;
let type_ = <DB as BackendExt>::return_type_for_id(&column.type_id)
.ok_or_else(|| format!("unknown field type ID: {}", &column.type_id))?
.parse::<proc_macro2::TokenStream>()
.unwrap();
Ok((name, type_))
})
.collect::<Result<Vec<_>, String>>()
.map_err(|e| format!("all SQL result columns must be named with valid Rust identifiers: {}", e))?;
let row_param = format_ident!("row");
let record_fields = fields.iter()
.map(|(name, type_)| quote!(#name: #type_,))
.collect::<TokenStream>();
let instantiations = fields.iter()
.enumerate()
.map(|(i, (name, _))| quote!(#name: #row_param.get(#i),))
.collect::<TokenStream>();
let backend = DB::quotable_path();
Ok(quote! {
#[derive(Debug)]
struct #type_name {
#record_fields
}
impl sqlx::FromRow<#backend> for #type_name {
fn from_row(#row_param: <#backend as sqlx::Backend>::Row) -> Self {
use sqlx::Row as _;
#type_name {
#instantiations
}
}
}
})
}
fn get_type_override(expr: &Expr) -> Option<proc_macro2::TokenStream> {
match expr {
Expr::Cast(cast) => Some(cast.ty.to_token_stream()),
Expr::Type(ascription) => Some(ascription.ty.to_token_stream()),
_ => None,
}
}

View File

@ -3,14 +3,15 @@ use sqlx::types::Uuid;
#[async_std::test]
async fn postgres_query() -> sqlx::Result<()> {
let mut conn =
sqlx::Connection::<sqlx::Postgres>::open(&env::var("DATABASE_URL").unwrap()).await?;
sqlx::Connection::<sqlx::Postgres>::open(&dotenv::var("DATABASE_URL").unwrap()).await?;
let uuid: Uuid = "256ba9c8-0048-11ea-b0f0-8f04859d047e".parse().unwrap();
let accounts = sqlx::query!("SELECT * from accounts where id != $1", None)
.fetch_optional(&mut conn)
let account = sqlx::query!("SELECT * from accounts where id != $1", uuid)
.fetch_one(&mut conn)
.await?;
println!("accounts: {:?}", accounts);
println!("account ID: {:?}", account.id);
println!("account name: {}", account.name);
Ok(())
}