mirror of
https://github.com/launchbadge/sqlx.git
synced 2026-01-04 07:41:21 +00:00
make query!() output anonymous records
This commit is contained in:
parent
871183d23b
commit
acca40c88e
@ -32,6 +32,7 @@ proc-macro-hack = { version = "0.5.11", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
async-std = { version = "1.2.0", features = [ "attributes" ] }
|
||||
dotenv = "0.15.0"
|
||||
matches = "0.1.8"
|
||||
criterion = "0.3.0"
|
||||
|
||||
|
||||
@ -52,26 +52,26 @@ where
|
||||
self.live.execute(query, params)
|
||||
}
|
||||
|
||||
fn fetch<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
|
||||
fn fetch<'e, 'q: 'e, I: 'e, T: 'e>(
|
||||
&'e mut self,
|
||||
query: &'q str,
|
||||
params: I,
|
||||
) -> BoxStream<'e, crate::Result<T>>
|
||||
where
|
||||
I: IntoQueryParameters<Self::Backend> + Send,
|
||||
T: FromRow<Self::Backend, O> + Send + Unpin,
|
||||
T: FromRow<Self::Backend> + Send + Unpin,
|
||||
{
|
||||
self.live.fetch(query, params)
|
||||
}
|
||||
|
||||
fn fetch_optional<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
|
||||
fn fetch_optional<'e, 'q: 'e, I: 'e, T: 'e>(
|
||||
&'e mut self,
|
||||
query: &'q str,
|
||||
params: I,
|
||||
) -> BoxFuture<'e, crate::Result<Option<T>>>
|
||||
where
|
||||
I: IntoQueryParameters<Self::Backend> + Send,
|
||||
T: FromRow<Self::Backend, O> + Send,
|
||||
T: FromRow<Self::Backend> + Send,
|
||||
{
|
||||
self.live.fetch_optional(query, params)
|
||||
}
|
||||
|
||||
@ -30,44 +30,44 @@ pub trait Executor: Send {
|
||||
where
|
||||
I: IntoQueryParameters<Self::Backend> + Send;
|
||||
|
||||
fn fetch<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
|
||||
fn fetch<'e, 'q: 'e, I: 'e, T: 'e>(
|
||||
&'e mut self,
|
||||
query: &'q str,
|
||||
params: I,
|
||||
) -> BoxStream<'e, crate::Result<T>>
|
||||
where
|
||||
I: IntoQueryParameters<Self::Backend> + Send,
|
||||
T: FromRow<Self::Backend, O> + Send + Unpin;
|
||||
T: FromRow<Self::Backend> + Send + Unpin;
|
||||
|
||||
fn fetch_all<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
|
||||
fn fetch_all<'e, 'q: 'e, I: 'e, T: 'e>(
|
||||
&'e mut self,
|
||||
query: &'q str,
|
||||
params: I,
|
||||
) -> BoxFuture<'e, crate::Result<Vec<T>>>
|
||||
where
|
||||
I: IntoQueryParameters<Self::Backend> + Send,
|
||||
T: FromRow<Self::Backend, O> + Send + Unpin,
|
||||
T: FromRow<Self::Backend> + Send + Unpin,
|
||||
{
|
||||
Box::pin(self.fetch(query, params).try_collect())
|
||||
}
|
||||
|
||||
fn fetch_optional<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
|
||||
fn fetch_optional<'e, 'q: 'e, I: 'e, T: 'e>(
|
||||
&'e mut self,
|
||||
query: &'q str,
|
||||
params: I,
|
||||
) -> BoxFuture<'e, crate::Result<Option<T>>>
|
||||
where
|
||||
I: IntoQueryParameters<Self::Backend> + Send,
|
||||
T: FromRow<Self::Backend, O> + Send;
|
||||
T: FromRow<Self::Backend> + Send;
|
||||
|
||||
fn fetch_one<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
|
||||
fn fetch_one<'e, 'q: 'e, I: 'e, T: 'e>(
|
||||
&'e mut self,
|
||||
query: &'q str,
|
||||
params: I,
|
||||
) -> BoxFuture<'e, crate::Result<T>>
|
||||
where
|
||||
I: IntoQueryParameters<Self::Backend> + Send,
|
||||
T: FromRow<Self::Backend, O> + Send,
|
||||
T: FromRow<Self::Backend> + Send,
|
||||
{
|
||||
let fut = self.fetch_optional(query, params);
|
||||
Box::pin(async move { fut.await?.ok_or(Error::NotFound) })
|
||||
|
||||
@ -50,6 +50,9 @@ pub use self::{
|
||||
#[doc(hidden)]
|
||||
pub use types::HasTypeMetadata;
|
||||
|
||||
#[doc(hidden)]
|
||||
pub use describe::{Describe, ResultField};
|
||||
|
||||
#[cfg(feature = "mariadb")]
|
||||
pub mod mariadb;
|
||||
|
||||
|
||||
@ -22,14 +22,14 @@ where
|
||||
Box::pin(async move { <&Pool<DB> as Executor>::execute(&mut &*self, query, params).await })
|
||||
}
|
||||
|
||||
fn fetch<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
|
||||
fn fetch<'e, 'q: 'e, I: 'e, T: 'e>(
|
||||
&'e mut self,
|
||||
query: &'q str,
|
||||
params: I,
|
||||
) -> BoxStream<'e, crate::Result<T>>
|
||||
where
|
||||
I: IntoQueryParameters<Self::Backend> + Send,
|
||||
T: FromRow<Self::Backend, O> + Send + Unpin,
|
||||
T: FromRow<Self::Backend> + Send + Unpin,
|
||||
{
|
||||
Box::pin(async_stream::try_stream! {
|
||||
let mut self_ = &*self;
|
||||
@ -41,14 +41,14 @@ where
|
||||
})
|
||||
}
|
||||
|
||||
fn fetch_optional<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
|
||||
fn fetch_optional<'e, 'q: 'e, I: 'e, T: 'e>(
|
||||
&'e mut self,
|
||||
query: &'q str,
|
||||
params: I,
|
||||
) -> BoxFuture<'e, crate::Result<Option<T>>>
|
||||
where
|
||||
I: IntoQueryParameters<Self::Backend> + Send,
|
||||
T: FromRow<Self::Backend, O> + Send,
|
||||
T: FromRow<Self::Backend> + Send,
|
||||
{
|
||||
Box::pin(async move {
|
||||
<&Pool<DB> as Executor>::fetch_optional(&mut &*self, query, params).await
|
||||
@ -80,14 +80,14 @@ where
|
||||
Box::pin(async move { self.0.acquire().await?.execute(query, params).await })
|
||||
}
|
||||
|
||||
fn fetch<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
|
||||
fn fetch<'e, 'q: 'e, I: 'e, T: 'e>(
|
||||
&'e mut self,
|
||||
query: &'q str,
|
||||
params: I,
|
||||
) -> BoxStream<'e, crate::Result<T>>
|
||||
where
|
||||
I: IntoQueryParameters<Self::Backend> + Send,
|
||||
T: FromRow<Self::Backend, O> + Send + Unpin,
|
||||
T: FromRow<Self::Backend> + Send + Unpin,
|
||||
{
|
||||
Box::pin(async_stream::try_stream! {
|
||||
let mut live = self.0.acquire().await?;
|
||||
@ -99,14 +99,14 @@ where
|
||||
})
|
||||
}
|
||||
|
||||
fn fetch_optional<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
|
||||
fn fetch_optional<'e, 'q: 'e, I: 'e, T: 'e>(
|
||||
&'e mut self,
|
||||
query: &'q str,
|
||||
params: I,
|
||||
) -> BoxFuture<'e, crate::Result<Option<T>>>
|
||||
where
|
||||
I: IntoQueryParameters<Self::Backend> + Send,
|
||||
T: FromRow<Self::Backend, O> + Send,
|
||||
T: FromRow<Self::Backend> + Send,
|
||||
{
|
||||
Box::pin(async move { self.0.acquire().await?.fetch_optional(query, params).await })
|
||||
}
|
||||
|
||||
@ -39,5 +39,5 @@ impl Backend for Postgres {
|
||||
}
|
||||
}
|
||||
|
||||
impl_from_row_for_backend!(Postgres);
|
||||
impl_from_row_for_backend!(Postgres, DataRow);
|
||||
impl_into_query_parameters_for_backend!(Postgres);
|
||||
|
||||
@ -40,14 +40,14 @@ impl Executor for Postgres {
|
||||
})
|
||||
}
|
||||
|
||||
fn fetch<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
|
||||
fn fetch<'e, 'q: 'e, I: 'e, T: 'e>(
|
||||
&'e mut self,
|
||||
query: &'q str,
|
||||
params: I,
|
||||
) -> BoxStream<'e, crate::Result<T>>
|
||||
where
|
||||
I: IntoQueryParameters<Self::Backend> + Send,
|
||||
T: FromRow<Self::Backend, O> + Send + Unpin,
|
||||
T: FromRow<Self::Backend> + Send + Unpin,
|
||||
{
|
||||
let params = params.into_params();
|
||||
|
||||
@ -66,14 +66,14 @@ impl Executor for Postgres {
|
||||
})
|
||||
}
|
||||
|
||||
fn fetch_optional<'e, 'q: 'e, I: 'e, O: 'e, T: 'e>(
|
||||
fn fetch_optional<'e, 'q: 'e, I: 'e, T: 'e>(
|
||||
&'e mut self,
|
||||
query: &'q str,
|
||||
params: I,
|
||||
) -> BoxFuture<'e, crate::Result<Option<T>>>
|
||||
where
|
||||
I: IntoQueryParameters<Self::Backend> + Send,
|
||||
T: FromRow<Self::Backend, O> + Send,
|
||||
T: FromRow<Self::Backend> + Send,
|
||||
{
|
||||
Box::pin(async move {
|
||||
let params = params.into_params();
|
||||
@ -137,7 +137,7 @@ impl Executor for Postgres {
|
||||
.into_vec()
|
||||
.into_iter()
|
||||
.map(|field| ResultField {
|
||||
name: Some(field.name),
|
||||
name: if field.name == "?column?" { None } else { Some(field.name) },
|
||||
table_id: Some(field.table_id),
|
||||
type_id: field.type_id,
|
||||
})
|
||||
|
||||
@ -33,7 +33,7 @@ where
|
||||
DB: Backend,
|
||||
DB::QueryParameters: 'q,
|
||||
I: IntoQueryParameters<DB> + Send,
|
||||
O: FromRow<DB, O> + Send + Unpin,
|
||||
O: FromRow<DB> + Send + Unpin,
|
||||
{
|
||||
#[inline]
|
||||
pub fn execute<E>(self, executor: &'q mut E) -> BoxFuture<'q, crate::Result<u64>>
|
||||
@ -72,7 +72,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<DB> Query<'_, DB, <DB as Backend>::QueryParameters>
|
||||
impl<DB> Query<'_, DB>
|
||||
where
|
||||
DB: Backend,
|
||||
{
|
||||
@ -93,6 +93,17 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<'q, DB, I, O> Query<'q, DB, I, O> where DB: Backend {
|
||||
pub fn with_output_type<O_>(self) -> Query<'q, DB, I, O_> {
|
||||
Query {
|
||||
query: self.query,
|
||||
input: self.input,
|
||||
output: PhantomData,
|
||||
backend: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a full SQL query using raw SQL.
|
||||
#[inline]
|
||||
pub fn query<DB>(query: &str) -> Query<'_, DB>
|
||||
|
||||
@ -16,7 +16,7 @@ pub trait Row: Send {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait FromRow<DB: Backend, O = <DB as Backend>::Row> {
|
||||
pub trait FromRow<DB: Backend> {
|
||||
fn from_row(row: <DB as Backend>::Row) -> Self;
|
||||
}
|
||||
|
||||
@ -36,27 +36,13 @@ macro_rules! impl_from_row {
|
||||
($(row.get($idx),)+)
|
||||
}
|
||||
}
|
||||
|
||||
// (T1, T2, ...) -> (T1, T2, ...)
|
||||
impl<$($T,)+> crate::row::FromRow<$B, ($($T,)+)> for ($($T,)+)
|
||||
where
|
||||
$($B: crate::types::HasSqlType<$T>,)+
|
||||
$($T: crate::decode::Decode<$B>,)+
|
||||
{
|
||||
#[inline]
|
||||
fn from_row(row: <$B as crate::backend::Backend>::Row) -> Self {
|
||||
use crate::row::Row;
|
||||
|
||||
($(row.get($idx),)+)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
macro_rules! impl_from_row_for_backend {
|
||||
($B:ident) => {
|
||||
impl crate::row::FromRow<$B> for <$B as crate::backend::Backend>::Row where $B: crate::Backend {
|
||||
($B:ident, $row:ident) => {
|
||||
impl crate::row::FromRow<$B> for $row where $B: crate::Backend {
|
||||
#[inline]
|
||||
fn from_row(row: <$B as crate::backend::Backend>::Row) -> Self {
|
||||
row
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
#[cfg(feature = "uuid")]
|
||||
pub use uuid::Uuid;
|
||||
|
||||
use std::fmt::Display;
|
||||
|
||||
/// Information about how a backend stores metadata about
|
||||
/// given SQL types.
|
||||
pub trait HasTypeMetadata {
|
||||
@ -8,7 +10,7 @@ pub trait HasTypeMetadata {
|
||||
type TypeMetadata: TypeMetadata<Self::TypeId>;
|
||||
|
||||
/// The Rust type of type identifiers in `DESCRIBE` responses for the SQL backend.
|
||||
type TypeId: Eq;
|
||||
type TypeId: Eq + Display;
|
||||
}
|
||||
|
||||
pub trait TypeMetadata<TypeId: Eq> {
|
||||
|
||||
@ -3,6 +3,10 @@ use sqlx::Backend;
|
||||
pub trait BackendExt: Backend {
|
||||
const BACKEND_PATH: &'static str;
|
||||
|
||||
fn quotable_path() -> syn::Path {
|
||||
syn::parse_str(Self::BACKEND_PATH).unwrap()
|
||||
}
|
||||
|
||||
fn param_type_for_id(id: &Self::TypeId) -> Option<&'static str>;
|
||||
|
||||
fn return_type_for_id(id: &Self::TypeId) -> Option<&'static str>;
|
||||
|
||||
@ -1,26 +1,19 @@
|
||||
#![cfg_attr(not(any(feature = "postgres", feature = "mariadb")), allow(dead_code, unused_macros, unused_imports))]
|
||||
extern crate proc_macro;
|
||||
|
||||
use proc_macro::TokenStream;
|
||||
|
||||
use proc_macro2::Span;
|
||||
|
||||
use proc_macro_hack::proc_macro_hack;
|
||||
|
||||
use quote::{quote, quote_spanned, ToTokens};
|
||||
use quote::{quote};
|
||||
|
||||
use syn::{
|
||||
parse::{self, Parse, ParseStream},
|
||||
parse,
|
||||
parse_macro_input,
|
||||
punctuated::Punctuated,
|
||||
spanned::Spanned,
|
||||
Expr, ExprLit, Lit, Token,
|
||||
};
|
||||
|
||||
use sqlx::{Executor, HasTypeMetadata};
|
||||
|
||||
use async_std::task;
|
||||
|
||||
use std::fmt::Display;
|
||||
use url::Url;
|
||||
|
||||
type Error = Box<dyn std::error::Error>;
|
||||
@ -28,46 +21,55 @@ type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
mod backend;
|
||||
|
||||
use backend::BackendExt;
|
||||
mod query;
|
||||
|
||||
struct MacroInput {
|
||||
sql: String,
|
||||
sql_span: Span,
|
||||
args: Vec<Expr>,
|
||||
}
|
||||
macro_rules! with_database(
|
||||
($db:ident => $expr:expr) => {
|
||||
async {
|
||||
let db_url = Url::parse(&dotenv::var("DATABASE_URL").map_err(|_| "DATABASE_URL not set")?)?;
|
||||
|
||||
impl Parse for MacroInput {
|
||||
fn parse(input: ParseStream) -> parse::Result<Self> {
|
||||
let mut args = Punctuated::<Expr, Token![,]>::parse_terminated(input)?.into_iter();
|
||||
match db_url.scheme() {
|
||||
#[cfg(feature = "postgres")]
|
||||
"postgresql" | "postgres" => {
|
||||
let $db = sqlx::Connection::<sqlx::Postgres>::open(db_url.as_str())
|
||||
.await
|
||||
.map_err(|e| format!("failed to connect to database: {}", e))?;
|
||||
|
||||
let sql = match args.next() {
|
||||
Some(Expr::Lit(ExprLit {
|
||||
lit: Lit::Str(sql), ..
|
||||
})) => sql,
|
||||
Some(other_expr) => {
|
||||
return Err(parse::Error::new_spanned(
|
||||
other_expr,
|
||||
"expected string literal",
|
||||
));
|
||||
$expr.await
|
||||
}
|
||||
#[cfg(not(feature = "postgres"))]
|
||||
"postgresql" | "postgres" => Err(format!(
|
||||
"DATABASE_URL {} has the scheme of a Postgres database but the `postgres` \
|
||||
feature of sqlx was not enabled",
|
||||
db_url
|
||||
).into()),
|
||||
#[cfg(feature = "mariadb")]
|
||||
"mysql" | "mariadb" => {
|
||||
let $db = sqlx::Connection::<sqlx::MariaDb>::open(db_url.as_str())
|
||||
.await
|
||||
.map_err(|e| format!("failed to connect to database: {}", e))?;
|
||||
|
||||
$expr.await
|
||||
}
|
||||
#[cfg(not(feature = "mariadb"))]
|
||||
"mysql" | "mariadb" => Err(format!(
|
||||
"DATABASE_URL {} has the scheme of a MySQL/MariaDB database but the `mariadb` \
|
||||
feature of sqlx was not enabled",
|
||||
db_url
|
||||
).into()),
|
||||
scheme => Err(format!("unexpected scheme {:?} in DATABASE_URL {}", scheme, db_url).into()),
|
||||
}
|
||||
None => return Err(input.error("expected SQL string literal")),
|
||||
};
|
||||
|
||||
Ok(MacroInput {
|
||||
sql: sql.value(),
|
||||
sql_span: sql.span(),
|
||||
args: args.collect(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
#[proc_macro_hack]
|
||||
pub fn query(input: TokenStream) -> TokenStream {
|
||||
let input = parse_macro_input!(input as MacroInput);
|
||||
let input = parse_macro_input!(input as query::MacroInput);
|
||||
|
||||
match task::block_on(process_sql(input)) {
|
||||
Ok(ts) => ts,
|
||||
Err(e) => {
|
||||
match task::block_on(with_database!(db => query::process_sql(input, db))) {
|
||||
Ok(ts) => ts.into(),
|
||||
Result::Err(e) => {
|
||||
if let Some(parse_err) = e.downcast_ref::<parse::Error>() {
|
||||
return parse_err.to_compile_error().into();
|
||||
}
|
||||
@ -77,131 +79,3 @@ pub fn query(input: TokenStream) -> TokenStream {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn process_sql(input: MacroInput) -> Result<TokenStream> {
|
||||
let db_url = Url::parse(&dotenv::var("DATABASE_URL")?)?;
|
||||
|
||||
match db_url.scheme() {
|
||||
#[cfg(feature = "postgres")]
|
||||
"postgresql" | "postgres" => {
|
||||
process_sql_with(
|
||||
input,
|
||||
sqlx::Connection::<sqlx::Postgres>::open(db_url.as_str())
|
||||
.await
|
||||
.map_err(|e| format!("failed to connect to database: {}", e))?,
|
||||
)
|
||||
.await
|
||||
}
|
||||
#[cfg(feature = "mariadb")]
|
||||
"mysql" | "mariadb" => {
|
||||
process_sql_with(
|
||||
input,
|
||||
sqlx::Connection::<sqlx::MariaDb>::open(db_url.as_str())
|
||||
.await
|
||||
.map_err(|e| format!("failed to connect to database: {}", e))?,
|
||||
)
|
||||
.await
|
||||
}
|
||||
scheme => Err(format!("unexpected scheme {:?} in DB_URL {}", scheme, db_url).into()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn process_sql_with<DB: BackendExt>(
|
||||
input: MacroInput,
|
||||
mut conn: sqlx::Connection<DB>,
|
||||
) -> Result<TokenStream>
|
||||
where
|
||||
<DB as HasTypeMetadata>::TypeId: Display,
|
||||
{
|
||||
let prepared = conn
|
||||
.describe(&input.sql)
|
||||
.await
|
||||
.map_err(|e| parse::Error::new(input.sql_span, e))?;
|
||||
|
||||
if input.args.len() != prepared.param_types.len() {
|
||||
return Err(parse::Error::new(
|
||||
Span::call_site(),
|
||||
format!(
|
||||
"expected {} parameters, got {}",
|
||||
prepared.param_types.len(),
|
||||
input.args.len()
|
||||
),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
let param_types = prepared
|
||||
.param_types
|
||||
.iter()
|
||||
.zip(&*input.args)
|
||||
.map(|(type_, expr)| {
|
||||
get_type_override(expr)
|
||||
.or_else(|| {
|
||||
Some(
|
||||
<DB as BackendExt>::param_type_for_id(type_)?
|
||||
.parse::<proc_macro2::TokenStream>()
|
||||
.unwrap(),
|
||||
)
|
||||
})
|
||||
.ok_or_else(|| format!("unknown type param ID: {}", type_).into())
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let output_types = prepared
|
||||
.result_fields
|
||||
.iter()
|
||||
.map(|column| {
|
||||
Ok(<DB as BackendExt>::return_type_for_id(&column.type_id)
|
||||
.ok_or_else(|| format!("unknown field type ID: {}", &column.type_id))?
|
||||
.parse::<proc_macro2::TokenStream>()
|
||||
.unwrap())
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let params_ty_cons = input.args.iter().enumerate().map(|(i, expr)| {
|
||||
// required or `quote!()` emits it as `Nusize`
|
||||
let i = syn::Index::from(i);
|
||||
quote_spanned!( expr.span() => { use sqlx::TyConsExt as _; (sqlx::TyCons::new(¶ms.#i)).ty_cons() })
|
||||
});
|
||||
|
||||
let query = &input.sql;
|
||||
let backend_path = syn::parse_str::<syn::Path>(DB::BACKEND_PATH).unwrap();
|
||||
|
||||
let params = if input.args.is_empty() {
|
||||
quote! {
|
||||
let params = ();
|
||||
}
|
||||
} else {
|
||||
let params = input.args.iter();
|
||||
|
||||
quote! {
|
||||
let params = (#(#params),*,);
|
||||
|
||||
if false {
|
||||
use sqlx::TyConsExt as _;
|
||||
|
||||
let _: (#(#param_types),*,) = (#(#params_ty_cons),*,);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(quote! {{
|
||||
#params
|
||||
|
||||
sqlx::Query::<#backend_path, _, (#(#output_types),*,)> {
|
||||
query: #query,
|
||||
input: params,
|
||||
output: ::core::marker::PhantomData,
|
||||
backend: ::core::marker::PhantomData,
|
||||
}
|
||||
}}
|
||||
.into())
|
||||
}
|
||||
|
||||
fn get_type_override(expr: &Expr) -> Option<proc_macro2::TokenStream> {
|
||||
match expr {
|
||||
Expr::Cast(cast) => Some(cast.ty.to_token_stream()),
|
||||
Expr::Type(ascription) => Some(ascription.ty.to_token_stream()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
200
sqlx-macros/src/query.rs
Normal file
200
sqlx-macros/src/query.rs
Normal file
@ -0,0 +1,200 @@
|
||||
use proc_macro2::Span;
|
||||
|
||||
use proc_macro2::TokenStream;
|
||||
|
||||
use syn::{
|
||||
parse::{self, Parse, ParseStream},
|
||||
punctuated::Punctuated,
|
||||
spanned::Spanned,
|
||||
Expr, ExprLit, Lit, Token, Ident,
|
||||
};
|
||||
|
||||
use crate::backend::BackendExt;
|
||||
|
||||
use quote::{format_ident, quote, quote_spanned, ToTokens};
|
||||
|
||||
use sqlx::{Executor, HasTypeMetadata};
|
||||
|
||||
use std::fmt::Display;
|
||||
|
||||
pub struct MacroInput {
|
||||
sql: String,
|
||||
sql_span: Span,
|
||||
args: Vec<Expr>,
|
||||
}
|
||||
|
||||
impl Parse for MacroInput {
|
||||
fn parse(input: ParseStream) -> parse::Result<Self> {
|
||||
let mut args = Punctuated::<Expr, Token![,]>::parse_terminated(input)?.into_iter();
|
||||
|
||||
let sql = match args.next() {
|
||||
Some(Expr::Lit(ExprLit {
|
||||
lit: Lit::Str(sql), ..
|
||||
})) => sql,
|
||||
Some(other_expr) => {
|
||||
return Err(parse::Error::new_spanned(
|
||||
other_expr,
|
||||
"expected string literal",
|
||||
));
|
||||
}
|
||||
None => return Err(input.error("expected SQL string literal")),
|
||||
};
|
||||
|
||||
Ok(MacroInput {
|
||||
sql: sql.value(),
|
||||
sql_span: sql.span(),
|
||||
args: args.collect(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Given an input like `query!("SELECT * FROM accounts WHERE account_id > ?", account_id)`
|
||||
pub async fn process_sql<DB: BackendExt>(
|
||||
input: MacroInput,
|
||||
mut conn: sqlx::Connection<DB>,
|
||||
) -> crate::Result<TokenStream>
|
||||
where
|
||||
<DB as HasTypeMetadata>::TypeId: Display,
|
||||
{
|
||||
let describe = conn
|
||||
.describe(&input.sql)
|
||||
.await
|
||||
.map_err(|e| parse::Error::new(input.sql_span, e))?;
|
||||
|
||||
if input.args.len() != describe.param_types.len() {
|
||||
return Err(parse::Error::new(
|
||||
Span::call_site(),
|
||||
format!(
|
||||
"expected {} parameters, got {}",
|
||||
describe.param_types.len(),
|
||||
input.args.len()
|
||||
),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
let param_types = describe
|
||||
.param_types
|
||||
.iter()
|
||||
.zip(&*input.args)
|
||||
.map(|(type_, expr)| {
|
||||
get_type_override(expr)
|
||||
.or_else(|| {
|
||||
Some(
|
||||
<DB as BackendExt>::param_type_for_id(type_)?
|
||||
.parse::<proc_macro2::TokenStream>()
|
||||
.unwrap(),
|
||||
)
|
||||
})
|
||||
.ok_or_else(|| format!("unknown type param ID: {}", type_).into())
|
||||
})
|
||||
.collect::<crate::Result<Vec<_>>>()?;
|
||||
|
||||
let params_ty_cons = input.args.iter().enumerate().map(|(i, expr)| {
|
||||
// required or `quote!()` emits it as `Nusize`
|
||||
let i = syn::Index::from(i);
|
||||
quote_spanned!( expr.span() => { use sqlx::TyConsExt as _; (sqlx::TyCons::new(¶ms.#i)).ty_cons() })
|
||||
});
|
||||
|
||||
let query = &input.sql;
|
||||
let backend_path = DB::quotable_path();
|
||||
|
||||
// record_type will be wrapped in parens which the compiler ignores without a trailing comma
|
||||
// e.g. (Foo) == Foo but (Foo,) = one-element tuple
|
||||
// and giving an empty stream for record_type makes it unit `()`
|
||||
let (record_type, record) = if describe.result_fields.is_empty() {
|
||||
(TokenStream::new(), TokenStream::new())
|
||||
} else {
|
||||
let record_type = Ident::new("Record", Span::call_site());
|
||||
(record_type.to_token_stream(), generate_record_def(&describe, &record_type)?)
|
||||
};
|
||||
|
||||
let params = if input.args.is_empty() {
|
||||
quote! {
|
||||
let params = ();
|
||||
}
|
||||
} else {
|
||||
let params = input.args.iter();
|
||||
|
||||
quote! {
|
||||
let params = (#(#params),*,);
|
||||
|
||||
if false {
|
||||
use sqlx::TyConsExt as _;
|
||||
|
||||
let _: (#(#param_types),*,) = (#(#params_ty_cons),*,);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(quote! {{
|
||||
#record
|
||||
|
||||
#params
|
||||
|
||||
sqlx::Query::<#backend_path, _, (#record_type)> {
|
||||
query: #query,
|
||||
input: params,
|
||||
output: ::core::marker::PhantomData,
|
||||
backend: ::core::marker::PhantomData,
|
||||
}
|
||||
}})
|
||||
}
|
||||
|
||||
fn generate_record_def<DB: BackendExt>(describe: &sqlx::Describe<DB>, type_name: &Ident) -> crate::Result<TokenStream> {
|
||||
let fields = describe.result_fields.iter().enumerate()
|
||||
.map(|(i, column)| {
|
||||
let name = column.name.as_deref()
|
||||
.ok_or_else(|| format!("column at position {} must have a name", i))?;
|
||||
|
||||
let name = syn::parse_str::<Ident>(name)
|
||||
.map_err(|_| format!("{:?} is not a valid Rust identifier", name))?;
|
||||
|
||||
let type_ = <DB as BackendExt>::return_type_for_id(&column.type_id)
|
||||
.ok_or_else(|| format!("unknown field type ID: {}", &column.type_id))?
|
||||
.parse::<proc_macro2::TokenStream>()
|
||||
.unwrap();
|
||||
|
||||
Ok((name, type_))
|
||||
})
|
||||
.collect::<Result<Vec<_>, String>>()
|
||||
.map_err(|e| format!("all SQL result columns must be named with valid Rust identifiers: {}", e))?;
|
||||
|
||||
let row_param = format_ident!("row");
|
||||
|
||||
let record_fields = fields.iter()
|
||||
.map(|(name, type_)| quote!(#name: #type_,))
|
||||
.collect::<TokenStream>();
|
||||
|
||||
let instantiations = fields.iter()
|
||||
.enumerate()
|
||||
.map(|(i, (name, _))| quote!(#name: #row_param.get(#i),))
|
||||
.collect::<TokenStream>();
|
||||
|
||||
let backend = DB::quotable_path();
|
||||
|
||||
Ok(quote! {
|
||||
#[derive(Debug)]
|
||||
struct #type_name {
|
||||
#record_fields
|
||||
}
|
||||
|
||||
impl sqlx::FromRow<#backend> for #type_name {
|
||||
fn from_row(#row_param: <#backend as sqlx::Backend>::Row) -> Self {
|
||||
use sqlx::Row as _;
|
||||
|
||||
#type_name {
|
||||
#instantiations
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn get_type_override(expr: &Expr) -> Option<proc_macro2::TokenStream> {
|
||||
match expr {
|
||||
Expr::Cast(cast) => Some(cast.ty.to_token_stream()),
|
||||
Expr::Type(ascription) => Some(ascription.ty.to_token_stream()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@ -3,14 +3,15 @@ use sqlx::types::Uuid;
|
||||
#[async_std::test]
|
||||
async fn postgres_query() -> sqlx::Result<()> {
|
||||
let mut conn =
|
||||
sqlx::Connection::<sqlx::Postgres>::open(&env::var("DATABASE_URL").unwrap()).await?;
|
||||
sqlx::Connection::<sqlx::Postgres>::open(&dotenv::var("DATABASE_URL").unwrap()).await?;
|
||||
|
||||
let uuid: Uuid = "256ba9c8-0048-11ea-b0f0-8f04859d047e".parse().unwrap();
|
||||
let accounts = sqlx::query!("SELECT * from accounts where id != $1", None)
|
||||
.fetch_optional(&mut conn)
|
||||
let account = sqlx::query!("SELECT * from accounts where id != $1", uuid)
|
||||
.fetch_one(&mut conn)
|
||||
.await?;
|
||||
|
||||
println!("accounts: {:?}", accounts);
|
||||
println!("account ID: {:?}", account.id);
|
||||
println!("account name: {}", account.name);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user