diff --git a/Cargo.toml b/Cargo.toml index 5928ca64..d2c1e9cf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ url = "2.1.0" [dev-dependencies] matches = "0.1.8" tokio = { version = "0.2.0-alpha.4", default-features = false, features = [ "rt-full" ] } +sqlx-macros = { path = "sqlx-macros/" } [profile.release] lto = true diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml new file mode 100644 index 00000000..1eae8dc9 --- /dev/null +++ b/sqlx-macros/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "sqlx-macros" +version = "0.1.0" +authors = ["Austin Bonander "] +edition = "2018" + +[lib] +proc-macro = true + +[dependencies] +futures-preview = "0.3.0-alpha.18" +hex = "0.4.0" +proc-macro2 = "1.0.6" +sqlx = { path = "../", features = ["postgres"] } +syn = "1.0" +quote = "1.0" +sha2 = "0.8.0" +tokio = { version = "0.2.0-alpha.4", default-features = false, features = [ "tcp" ] } + diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs new file mode 100644 index 00000000..75b0b3ca --- /dev/null +++ b/sqlx-macros/src/lib.rs @@ -0,0 +1,46 @@ +extern crate proc_macro; + +use proc_macro::TokenStream; + +use quote::quote; + +use syn::parse_macro_input; + +use sha2::{Sha256, Digest}; +use sqlx::Postgres; + +use tokio::runtime::Runtime; + +type Error = Box; +type Result = std::result::Result; + +#[proc_macro] +pub fn sql(input: TokenStream) -> TokenStream { + let string = parse_macro_input!(input as syn::LitStr).value(); + + eprintln!("expanding macro"); + + match Runtime::new().map_err(Error::from).and_then(|runtime| runtime.block_on(process_sql(&string))) { + Ok(ts) => ts, + Err(e) => { + let msg = e.to_string(); + quote! ( compile_error!(#msg) ).into() + } + } +} + +async fn process_sql(sql: &str) -> Result { + let hash = dbg!(hex::encode(&Sha256::digest(sql.as_bytes()))); + + let conn = sqlx::Connection::::establish("postgresql://postgres@127.0.0.1/sqlx_test") + .await + .map_err(|e| format!("failed to connect to database: {}", e))?; + + eprintln!("connection established"); + + let prepared = conn.prepare(&hash, sql).await?; + + let msg = format!("{:?}", prepared); + + Ok(quote! { compile_error!(#msg) }.into()) +} diff --git a/src/connection.rs b/src/connection.rs index 9da34521..18c43f78 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -75,7 +75,7 @@ pub trait RawConnection: Send { async fn prepare(&mut self, name: &str, body: &str) -> crate::Result { // TODO: implement for other backends - Err("connection does not support prepare() operation".into()) + unimplemented!() } } @@ -130,7 +130,7 @@ where /// Prepares a statement. pub async fn prepare(&self, name: &str, body: &str) -> crate::Result { let mut live = self.0.acquire().await; - let ret = live.raw.prepare(name, body)?; + let ret = live.raw.prepare(name, body).await?; self.0.release(live); Ok(ret) } diff --git a/src/io/buf.rs b/src/io/buf.rs index ff73cfe8..65cc1b7a 100644 --- a/src/io/buf.rs +++ b/src/io/buf.rs @@ -9,6 +9,8 @@ pub trait Buf { fn get_u16(&mut self) -> io::Result; + fn get_i16(&mut self) -> io::Result; + fn get_u24(&mut self) -> io::Result; fn get_i32(&mut self) -> io::Result; @@ -42,6 +44,13 @@ impl<'a> Buf for &'a [u8] { Ok(val) } + fn get_i16(&mut self) -> io::Result { + let val = T::read_i16(*self); + self.advance(2); + + Ok(val) + } + fn get_i32(&mut self) -> io::Result { let val = T::read_i32(*self); self.advance(4); diff --git a/src/postgres/connection.rs b/src/postgres/connection.rs index a7da6f5f..ef6725b2 100644 --- a/src/postgres/connection.rs +++ b/src/postgres/connection.rs @@ -1,8 +1,10 @@ use super::{Postgres, PostgresQueryParameters, PostgresRawConnection, PostgresRow}; use crate::{connection::RawConnection, postgres::raw::Step, url::Url, Error}; +use crate::query::QueryParameters; use async_trait::async_trait; use futures_core::stream::BoxStream; use crate::prepared::{PreparedStatement, Field}; +use crate::postgres::error::ProtocolError; #[async_trait] impl RawConnection for PostgresRawConnection { @@ -94,21 +96,26 @@ impl RawConnection for PostgresRawConnection { Ok(row) } - fn prepare(&mut self, name: &str, body: &str) -> crate::Result { - self.parse(name, body, &[]); + async fn prepare(&mut self, name: &str, body: &str) -> crate::Result { + self.parse(name, body, &PostgresQueryParameters::new()); self.describe(name); + self.sync().await?; let param_desc= loop { - if let Step::ParamDesc(desc) = self.step().await? - .ok_or("did not receive ParameterDescription")? + let step = self.step().await? + .ok_or(ProtocolError("did not receive ParameterDescription")); + + if let Step::ParamDesc(desc) = dbg!(step)? { break desc; } }; let row_desc = loop { - if let Step::RowDesc(desc) = self.step().await? - .ok_or("did not receive RowDescription")? + let step = self.step().await? + .ok_or(ProtocolError("did not receive RowDescription")); + + if let Step::RowDesc(desc) = dbg!(step)? { break desc; } diff --git a/src/postgres/error.rs b/src/postgres/error.rs index 8be36995..fbd00f2a 100644 --- a/src/postgres/error.rs +++ b/src/postgres/error.rs @@ -1,11 +1,22 @@ use super::protocol::Response; use crate::error::DatabaseError; +use std::borrow::Cow; +use std::fmt::Debug; #[derive(Debug)] pub struct PostgresDatabaseError(pub(super) Box); +#[derive(Debug)] +pub struct ProtocolError(pub(super) T); + impl DatabaseError for PostgresDatabaseError { fn message(&self) -> &str { self.0.message() } } + +impl + Debug + Send + Sync> DatabaseError for ProtocolError { + fn message(&self) -> &str { + self.0.as_ref() + } +} diff --git a/src/postgres/protocol/describe.rs b/src/postgres/protocol/describe.rs index 6df51885..d042c312 100644 --- a/src/postgres/protocol/describe.rs +++ b/src/postgres/protocol/describe.rs @@ -9,11 +9,11 @@ pub enum DescribeKind { } pub struct Describe<'a> { - kind: DescribeKind, + pub kind: DescribeKind, /// The name of the prepared statement or portal to describe (an empty string selects the /// unnamed prepared statement or portal). - name: &'a str, + pub name: &'a str, } impl Encode for Describe<'_> { diff --git a/src/postgres/protocol/mod.rs b/src/postgres/protocol/mod.rs index d908c5dc..9e81a143 100644 --- a/src/postgres/protocol/mod.rs +++ b/src/postgres/protocol/mod.rs @@ -65,10 +65,10 @@ fn read_string(buf: &mut &[u8]) -> io::Result { let str_len = memchr::memchr(0u8, buf) .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "unterminated string"))?; - let string = str::from_utf8(&*buf[..str_len]) + let string = str::from_utf8(&buf[..str_len]) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - *buf = &*buf[str_len + 1..]; + *buf = &buf[str_len + 1..]; Ok(string.to_owned()) } diff --git a/src/postgres/protocol/row_description.rs b/src/postgres/protocol/row_description.rs index c767f211..10740057 100644 --- a/src/postgres/protocol/row_description.rs +++ b/src/postgres/protocol/row_description.rs @@ -26,15 +26,15 @@ impl Decode for RowDescription { let mut fields = Vec::with_capacity(cnt); for _ in 0..cnt { - fields.push(RowField { + fields.push(dbg!(RowField { name: super::read_string(&mut buf)?, table_id: buf.get_u32::()?, attr_num: buf.get_i16::()?, type_id: buf.get_u32::()?, - type_size: buf.get_16::()?, + type_size: buf.get_i16::()?, type_mod: buf.get_i32::()?, format_code: buf.get_i16::()?, - }); + })); } Ok(Self { diff --git a/src/postgres/raw.rs b/src/postgres/raw.rs index 929c8c65..110ad33a 100644 --- a/src/postgres/raw.rs +++ b/src/postgres/raw.rs @@ -204,7 +204,9 @@ impl PostgresRawConnection { return Ok(Some(Step::ParamDesc(desc))); }, - Message:: + Message::RowDescription(desc) => { + return Ok(Some(Step::RowDesc(desc))); + }, message => { return Err(io::Error::new( @@ -296,6 +298,7 @@ impl PostgresRawConnection { } } +#[derive(Debug)] pub(super) enum Step { Command(u64), Row(PostgresRow), diff --git a/src/postgres/row.rs b/src/postgres/row.rs index d115b51e..361b49b1 100644 --- a/src/postgres/row.rs +++ b/src/postgres/row.rs @@ -1,6 +1,7 @@ use super::{protocol::DataRow, Postgres}; use crate::row::Row; +#[derive(Debug)] pub struct PostgresRow(pub(crate) DataRow); impl Row for PostgresRow { diff --git a/src/prepared.rs b/src/prepared.rs index e096f474..9292d4d4 100644 --- a/src/prepared.rs +++ b/src/prepared.rs @@ -1,9 +1,11 @@ +#[derive(Debug)] pub struct PreparedStatement { pub name: String, pub param_types: Box<[u32]>, pub fields: Vec, } +#[derive(Debug)] pub struct Field { pub name: String, pub table_id: u32, diff --git a/tests/sql-macro-test.rs b/tests/sql-macro-test.rs index 382c576b..1e4a5161 100644 --- a/tests/sql-macro-test.rs +++ b/tests/sql-macro-test.rs @@ -1,3 +1,5 @@ +#![feature(proc_macro_hygiene)] + fn main() { - sqlx::sql!("SELECT * from accounts"); + sqlx_macros::sql!("SELECT * from accounts"); }