From 020eed90c886fa71a35c14a770a807e93c1a9efb Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 21 Nov 2019 21:38:52 +0000 Subject: [PATCH] port to async-std, misc fixes --- .gitignore | 1 + Cargo.toml | 13 ++++++------- sqlx-macros/Cargo.toml | 3 ++- sqlx-macros/src/backend/mod.rs | 6 +++++- sqlx-macros/src/lib.rs | 24 +++++++++++------------- src/connection.rs | 6 +++--- src/describe.rs | 26 +++++++++++++++++++++++++- src/io/buf_stream.rs | 9 ++++++--- src/mariadb/connection.rs | 19 +++++++++++-------- src/pool.rs | 6 +++--- src/postgres/connection.rs | 4 ++-- src/postgres/raw.rs | 2 +- src/query.rs | 8 ++++---- tests/sql-macro-test.rs | 6 +++--- 14 files changed, 83 insertions(+), 50 deletions(-) diff --git a/.gitignore b/.gitignore index fb667367..0be49a38 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ Cargo.lock .idea/ *.vim *.vi +.env diff --git a/Cargo.toml b/Cargo.toml index 2ccf58f5..1b65b982 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,27 +15,26 @@ postgres = [] mariadb = [] [dependencies] -async-stream = "0.1.1" +async-std = { version = "1.1", features = ["attributes"] } +async-stream = "0.2" async-trait = "0.1.11" bitflags = "1.1.0" byteorder = { version = "1.3.2", default-features = false } bytes = "0.4.12" crossbeam-queue = "0.1.2" crossbeam-utils = { version = "0.6.6", default-features = false } -futures-channel-preview = "0.3.0-alpha.18" -futures-core-preview = "0.3.0-alpha.18" -futures-util-preview = "0.3.0-alpha.18" +futures-channel = "0.3.1" +futures-core = "0.3.1" +futures-util = "0.3.1" log = "0.4.8" md-5 = "0.8.0" memchr = "2.2.1" -tokio = { version = "0.2.0-alpha.4", default-features = false, features = [ "tcp" ] } url = "2.1.0" uuid = { version = "0.8.1", optional = true } [dev-dependencies] matches = "0.1.8" -tokio = { version = "0.2.0-alpha.4", default-features = false, features = [ "rt-full" ] } -sqlx-macros = { path = "sqlx-macros/" } +sqlx-macros = { path = "sqlx-macros/", features = ["postgres", "mariadb", "uuid"] } criterion = "0.3" [profile.release] diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml index 2c6d5e99..4148dc96 100644 --- a/sqlx-macros/Cargo.toml +++ b/sqlx-macros/Cargo.toml @@ -8,8 +8,9 @@ edition = "2018" proc-macro = true [dependencies] +async-std = "1.0" dotenv = "0.15.0" -futures-preview = "0.3.0-alpha.18" +futures = "0.3.1" proc-macro2 = "1.0.6" sqlx = { path = "../" } syn = "1.0" diff --git a/sqlx-macros/src/backend/mod.rs b/sqlx-macros/src/backend/mod.rs index 4dd0f4cb..5efdd5d7 100644 --- a/sqlx-macros/src/backend/mod.rs +++ b/sqlx-macros/src/backend/mod.rs @@ -1,14 +1,18 @@ use sqlx::Backend; pub trait BackendExt: Backend { + const BACKEND_PATH: &'static str; + fn param_type_for_id(id: &Self::TypeId) -> Option<&'static str>; fn return_type_for_id(id: &Self::TypeId) -> Option<&'static str>; } macro_rules! impl_backend_ext { - ($backend:ty { $($(#[$meta:meta])? $ty:ty $(| $borrowed:ty)?),* }) => { + ($backend:path { $($(#[$meta:meta])? $ty:ty $(| $borrowed:ty)?),* }) => { impl $crate::backend::BackendExt for $backend { + const BACKEND_PATH: &'static str = stringify!($backend); + fn param_type_for_id(id: &Self::TypeId) -> Option<&'static str> { use sqlx::types::TypeMetadata; diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs index 234382cb..de3df772 100644 --- a/sqlx-macros/src/lib.rs +++ b/sqlx-macros/src/lib.rs @@ -4,7 +4,7 @@ use proc_macro::TokenStream; use proc_macro2::Span; -use quote::{quote, quote_spanned, ToTokens}; +use quote::{quote, quote_spanned, format_ident, ToTokens}; use syn::{ parse::{self, Parse, ParseStream}, @@ -16,7 +16,7 @@ use syn::{ use sqlx::HasTypeMetadata; -use tokio::runtime::Runtime; +use async_std::task; use std::fmt::Display; use url::Url; @@ -65,10 +65,7 @@ pub fn sql(input: TokenStream) -> TokenStream { eprintln!("expanding macro"); - match Runtime::new() - .map_err(Error::from) - .and_then(|runtime| runtime.block_on(process_sql(input))) - { + match task::block_on(process_sql(input)) { Ok(ts) => { eprintln!("emitting output: {}", ts); ts @@ -98,13 +95,11 @@ async fn process_sql(input: MacroInput) -> Result { ) .await } - #[cfg(feature = "mysql")] + #[cfg(feature = "mariadb")] "mysql" => { process_sql_with( input, - sqlx::Connection::::establish( - "postgresql://postgres@127.0.0.1/sqlx_test", - ) + sqlx::Connection::::establish(db_url.as_str()) .await .map_err(|e| format!("failed to connect to database: {}", e))?, ) @@ -153,7 +148,7 @@ where .unwrap(), ) }) - .ok_or_else(|| format!("unknown type ID: {}", type_).into()) + .ok_or_else(|| format!("unknown type param ID: {}", type_).into()) }) .collect::>>()?; @@ -162,7 +157,7 @@ where .iter() .map(|column| { Ok(::return_type_for_id(&column.type_id) - .ok_or_else(|| format!("unknown type ID: {}", &column.type_id))? + .ok_or_else(|| format!("unknown field type ID: {}", &column.type_id))? .parse::() .unwrap()) }) @@ -171,10 +166,13 @@ where let params = input.args.iter(); let params_ty_cons = input.args.iter().enumerate().map(|(i, expr)| { + // required or `quote!()` emits it as `Nusize` + let i = syn::Index::from(i); quote_spanned!( expr.span() => { use sqlx::TyConsExt as _; (sqlx::TyCons::new(¶ms.#i)).ty_cons() }) }); let query = &input.sql; + let backend_path = syn::parse_str::(DB::BACKEND_PATH).unwrap(); Ok(quote! {{ use sqlx::TyConsExt as _; @@ -185,7 +183,7 @@ where let _: (#(#param_types),*,) = (#(#params_ty_cons),*,); } - sqlx::CompiledSql::<_, (#(#output_types),*), sqlx::Postgres> { + sqlx::CompiledSql::<_, (#(#output_types),*), #backend_path> { query: #query, params, output: ::core::marker::PhantomData, diff --git a/src/connection.rs b/src/connection.rs index 9f73117d..52819631 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -152,7 +152,7 @@ where { Box::pin(async move { let mut live = self.0.acquire().await; - let result = live.raw.execute(query, params.into()).await; + let result = live.raw.execute(query, params.into_params()).await; self.0.release(live); result @@ -170,7 +170,7 @@ where { Box::pin(async_stream::try_stream! { let mut live = self.0.acquire().await; - let mut s = live.raw.fetch(query, params.into()); + let mut s = live.raw.fetch(query, params.into_params()); while let Some(row) = s.next().await.transpose()? { yield T::from_row(row); @@ -192,7 +192,7 @@ where { Box::pin(async move { let mut live = self.0.acquire().await; - let row = live.raw.fetch_optional(query, params.into()).await?; + let row = live.raw.fetch_optional(query, params.into_params()).await?; self.0.release(live); Ok(row.map(T::from_row)) diff --git a/src/describe.rs b/src/describe.rs index 2bfece12..271a80dd 100644 --- a/src/describe.rs +++ b/src/describe.rs @@ -2,18 +2,42 @@ use crate::Backend; use crate::types::HasTypeMetadata; +use std::fmt; + /// The result of running prepare + describe for the given backend. pub struct Describe { - /// /// The expected type IDs of bind parameters. pub param_types: Vec<::TypeId>, /// pub result_fields: Vec>, } +impl fmt::Debug for Describe + where ::TypeId: fmt::Debug, ResultField: fmt::Debug +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Describe") + .field("param_types", &self.param_types) + .field("result_fields", &self.result_fields) + .finish() + } +} + pub struct ResultField { pub name: Option, pub table_id: Option<::TableIdent>, /// The type ID of this result column. pub type_id: ::TypeId, } + +impl fmt::Debug for ResultField + where ::TableIdent: fmt::Debug, ::TypeId: fmt::Debug +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("ResultField") + .field("name", &self.name) + .field("table_id", &self.table_id) + .field("type_id", &self.type_id) + .finish() + } +} diff --git a/src/io/buf_stream.rs b/src/io/buf_stream.rs index fb8f3b83..c3882c29 100644 --- a/src/io/buf_stream.rs +++ b/src/io/buf_stream.rs @@ -1,6 +1,7 @@ use bytes::{BufMut, BytesMut}; use std::io; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +use async_std::io::prelude::*; pub struct BufStream { stream: S, @@ -17,7 +18,7 @@ pub struct BufStream { impl BufStream where - S: AsyncRead + AsyncWrite + Unpin, + S: Read + Write + Unpin, { pub fn new(stream: S) -> Self { Self { @@ -29,7 +30,9 @@ where } pub async fn close(&mut self) -> io::Result<()> { - self.stream.shutdown().await + use futures_util::io::AsyncWriteExt; + + self.stream.close().await } #[inline] diff --git a/src/mariadb/connection.rs b/src/mariadb/connection.rs index cadab8ce..04b884b0 100644 --- a/src/mariadb/connection.rs +++ b/src/mariadb/connection.rs @@ -23,7 +23,7 @@ use std::{ io, net::{IpAddr, SocketAddr}, }; -use tokio::net::TcpStream; +use async_std::net::TcpStream; use url::Url; pub struct MariaDbRawConnection { @@ -367,13 +367,13 @@ mod test { use super::*; use crate::{query::QueryParameters, Error, Pool}; - #[tokio::test] + #[async_std::test] async fn it_can_connect() -> Result<()> { MariaDbRawConnection::establish("mariadb://root@127.0.0.1:3306/test").await?; Ok(()) } - #[tokio::test] + #[async_std::test] async fn it_fails_to_connect_with_bad_username() -> Result<()> { match MariaDbRawConnection::establish("mariadb://roote@127.0.0.1:3306/test").await { Ok(_) => panic!("Somehow connected to database with incorrect username"), @@ -381,7 +381,7 @@ mod test { } } - #[tokio::test] + #[async_std::test] async fn it_can_ping() -> Result<()> { let mut conn = MariaDbRawConnection::establish("mariadb://root@127.0.0.1:3306/test").await?; @@ -389,15 +389,18 @@ mod test { Ok(()) } - #[tokio::test] + #[async_std::test] async fn it_can_describe() -> Result<()> { let mut conn = - MariaDbRawConnection::establish("mariadb://root@127.0.0.1:3306/test").await?; - conn.describe("SELECT id from users").await?; + MariaDbRawConnection::establish("mysql://sqlx_user@127.0.0.1:3306/sqlx_test").await?; + let describe = conn.describe("SELECT id from accounts where id = ?").await?; + + dbg!(describe); + Ok(()) } - #[tokio::test] + #[async_std::test] async fn it_can_create_mariadb_pool() -> Result<()> { let pool: Pool = Pool::new("mariadb://root@127.0.0.1:3306/test").await?; Ok(()) diff --git a/src/pool.rs b/src/pool.rs index 43587ab4..a0422256 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -289,7 +289,7 @@ where { Box::pin(async move { let mut live = self.0.acquire().await?; - let result = live.raw.execute(query, params.into()).await; + let result = live.raw.execute(query, params.into_params()).await; self.0.release(live); result @@ -307,7 +307,7 @@ where { Box::pin(async_stream::try_stream! { let mut live = self.0.acquire().await?; - let mut s = live.raw.fetch(query, params.into()); + let mut s = live.raw.fetch(query, params.into_params()); while let Some(row) = s.next().await.transpose()? { yield T::from_row(row); @@ -329,7 +329,7 @@ where { Box::pin(async move { let mut live = self.0.acquire().await?; - let row = live.raw.fetch_optional(query, params.into()).await?; + let row = live.raw.fetch_optional(query, params.into_params()).await?; self.0.release(live); diff --git a/src/postgres/connection.rs b/src/postgres/connection.rs index 46f81676..6f810992 100644 --- a/src/postgres/connection.rs +++ b/src/postgres/connection.rs @@ -159,7 +159,7 @@ mod tests { .unwrap() } - #[tokio::test] + #[async_std::test] #[ignore] async fn it_establishes() -> crate::Result<()> { let mut conn = PostgresRawConnection::establish(&database_url()).await?; @@ -173,7 +173,7 @@ mod tests { Ok(()) } - #[tokio::test] + #[async_std::test] #[ignore] async fn it_executes() -> crate::Result<()> { let mut conn = PostgresRawConnection::establish(&database_url()).await?; diff --git a/src/postgres/raw.rs b/src/postgres/raw.rs index bfe91d1a..c68c2c35 100644 --- a/src/postgres/raw.rs +++ b/src/postgres/raw.rs @@ -7,7 +7,7 @@ use crate::{ }; use byteorder::NetworkEndian; use std::{io, net::SocketAddr}; -use tokio::net::TcpStream; +use async_std::net::TcpStream; pub struct PostgresRawConnection { stream: BufStream, diff --git a/src/query.rs b/src/query.rs index 6efbffd2..88458590 100644 --- a/src/query.rs +++ b/src/query.rs @@ -17,7 +17,7 @@ pub trait IntoQueryParameters where DB: Backend, { - fn into(self) -> DB::QueryParameters; + fn into_params(self) -> DB::QueryParameters; } #[allow(unused)] @@ -28,7 +28,7 @@ macro_rules! impl_into_query_parameters { $($B: crate::types::HasSqlType<$T>,)+ $($T: crate::serialize::ToSql<$B>,)+ { - fn into(self) -> <$B as crate::backend::Backend>::QueryParameters { + fn into_params(self) -> <$B as crate::backend::Backend>::QueryParameters { let mut params = <<$B as crate::backend::Backend>::QueryParameters as crate::query::QueryParameters>::new(); @@ -45,7 +45,7 @@ where DB: Backend, { #[inline] - fn into(self) -> DB::QueryParameters { + fn into_params(self) -> DB::QueryParameters { self } } @@ -56,7 +56,7 @@ macro_rules! impl_into_query_parameters_for_backend { impl crate::query::IntoQueryParameters<$B> for () { #[inline] - fn into(self) -> <$B as crate::backend::Backend>::QueryParameters { + fn into_params(self) -> <$B as crate::backend::Backend>::QueryParameters { <<$B as crate::backend::Backend>::QueryParameters as crate::query::QueryParameters>::new() } diff --git a/tests/sql-macro-test.rs b/tests/sql-macro-test.rs index 5bcb2ee7..99647acf 100644 --- a/tests/sql-macro-test.rs +++ b/tests/sql-macro-test.rs @@ -1,12 +1,12 @@ #![feature(proc_macro_hygiene)] -#[tokio::test] +#[async_std::test] async fn test_sqlx_macro() -> sqlx::Result<()> { let conn = sqlx::Connection::::establish("postgres://postgres@127.0.0.1/sqlx_test") .await?; - let uuid: sqlx::Uuid = "256ba9c8-0048-11ea-b0f0-8f04859d047e".parse().unwrap(); - let accounts = sqlx_macros::sql!("SELECT * from accounts where id = $1", 5i64) + let uuid: sqlx::types::Uuid = "256ba9c8-0048-11ea-b0f0-8f04859d047e".parse().unwrap(); + let accounts = sqlx_macros::sql!("SELECT * from accounts where id = $1", None) .fetch_one(&conn) .await?;