From bbdc03c5760517e4664385aa43a7ee44533d4f1f Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Tue, 12 Nov 2019 20:57:55 -0800 Subject: [PATCH] unified prepared statement interface --- sqlx-macros/Cargo.toml | 4 +- sqlx-macros/src/lib.rs | 170 +++++++++++++++-------- sqlx-macros/src/postgres.rs | 45 ------ src/backend.rs | 12 +- src/bin/cargo-sqlx.rs | 17 ++- src/compiled.rs | 51 ++++--- src/connection.rs | 22 ++- src/error.rs | 6 +- src/lib.rs | 6 +- src/mariadb/backend.rs | 2 + src/mariadb/connection.rs | 112 +++++++++------ src/mariadb/error.rs | 21 +++ src/mariadb/mod.rs | 1 + src/mariadb/protocol/error_code.rs | 34 ++++- src/mariadb/protocol/response/err.rs | 11 +- src/mariadb/types/mod.rs | 44 +++++- src/postgres/backend.rs | 2 + src/postgres/connection.rs | 90 +++++++++--- src/postgres/error.rs | 19 ++- src/postgres/protocol/message.rs | 2 +- src/postgres/protocol/mod.rs | 35 +++-- src/postgres/protocol/row_description.rs | 9 +- src/postgres/raw.rs | 15 +- src/postgres/types/binary.rs | 41 ++++++ src/postgres/types/mod.rs | 62 ++++++++- src/prepared.rs | 32 +++-- src/serialize.rs | 6 +- src/types.rs | 32 ++++- tests/sql-macro-test.rs | 4 +- 29 files changed, 654 insertions(+), 253 deletions(-) delete mode 100644 sqlx-macros/src/postgres.rs create mode 100644 src/mariadb/error.rs create mode 100644 src/postgres/types/binary.rs diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml index 0fde9a2b..00437c5c 100644 --- a/sqlx-macros/Cargo.toml +++ b/sqlx-macros/Cargo.toml @@ -8,14 +8,14 @@ edition = "2018" proc-macro = true [dependencies] +dotenv = "0.15.0" futures-preview = "0.3.0-alpha.18" -hex = "0.4.0" proc-macro2 = "1.0.6" sqlx = { path = "../", features = ["postgres"] } syn = "1.0" quote = "1.0" -sha2 = "0.8.0" tokio = { version = "0.2.0-alpha.4", default-features = false, features = [ "tcp" ] } +url = "2.1.0" [features] postgres = ["sqlx/postgres"] diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs index 7ca8dbdc..8e68104e 100644 --- a/sqlx-macros/src/lib.rs +++ b/sqlx-macros/src/lib.rs @@ -6,47 +6,52 @@ use proc_macro2::Span; use quote::{format_ident, quote, quote_spanned, ToTokens}; -use syn::{parse_macro_input, Expr, ExprLit, Lit, LitStr, Token, Type}; -use syn::spanned::Spanned; -use syn::punctuated::Punctuated; -use syn::parse::{self, Parse, ParseStream}; +use syn::{ + parse::{self, Parse, ParseStream}, + parse_macro_input, + punctuated::Punctuated, + spanned::Spanned, + Expr, ExprLit, Lit, LitStr, Token, Type, +}; -use sha2::{Sha256, Digest}; -use sqlx::Postgres; +use sqlx::{HasTypeMetadata, Postgres}; use tokio::runtime::Runtime; -use std::error::Error as _; +use std::{error::Error as _, fmt::Display, str::FromStr}; +use url::Url; type Error = Box; type Result = std::result::Result; -mod postgres; - struct MacroInput { sql: String, sql_span: Span, - args: Vec + args: Vec, } impl Parse for MacroInput { fn parse(input: ParseStream) -> parse::Result { - let mut args = Punctuated::::parse_terminated(input)? - .into_iter(); + let mut args = Punctuated::::parse_terminated(input)?.into_iter(); let sql = match args.next() { - Some(Expr::Lit(ExprLit { lit: Lit::Str(sql), .. })) => sql, - Some(other_expr) => return Err(parse::Error::new_spanned(other_expr, "expected string literal")), + Some(Expr::Lit(ExprLit { + lit: Lit::Str(sql), .. + })) => sql, + Some(other_expr) => { + return Err(parse::Error::new_spanned( + other_expr, + "expected string literal", + )); + } None => return Err(input.error("expected SQL string literal")), }; - Ok( - MacroInput { - sql: sql.value(), - sql_span: sql.span(), - args: args.collect(), - } - ) + Ok(MacroInput { + sql: sql.value(), + sql_span: sql.span(), + args: args.collect(), + }) } } @@ -56,52 +61,109 @@ pub fn sql(input: TokenStream) -> TokenStream { eprintln!("expanding macro"); - match Runtime::new().map_err(Error::from).and_then(|runtime| runtime.block_on(process_sql(input))) { + match Runtime::new() + .map_err(Error::from) + .and_then(|runtime| runtime.block_on(process_sql(input))) + { Ok(ts) => { eprintln!("emitting output: {}", ts); ts - }, + } Err(e) => { if let Some(parse_err) = e.downcast_ref::() { return parse_err.to_compile_error().into(); } let msg = e.to_string(); - quote! ( compile_error!(#msg) ).into() + quote!(compile_error!(#msg)).into() } } } async fn process_sql(input: MacroInput) -> Result { - let hash = dbg!(hex::encode(&Sha256::digest(input.sql.as_bytes()))); + let db_url = Url::parse(&dotenv::var("DB_URL")?)?; - let conn = sqlx::Connection::::establish("postgresql://postgres@127.0.0.1/sqlx_test") - .await - .map_err(|e| format!("failed to connect to database: {}", e))?; + match db_url.scheme() { + #[cfg(feature = "postgres")] + "postgresql" => { + process_sql_with( + input, + sqlx::Connection::::establish(db_url.as_str()) + .await + .map_err(|e| format!("failed to connect to database: {}", e))?, + ) + .await + } + #[cfg(feature = "mysql")] + "mysql" => { + process_sql_with( + input, + sqlx::Connection::::establish( + "postgresql://postgres@127.0.0.1/sqlx_test", + ) + .await + .map_err(|e| format!("failed to connect to database: {}", e))?, + ) + .await + } + scheme => Err(format!("unexpected scheme {:?} in DB_URL {}", scheme, db_url).into()), + } +} +async fn process_sql_with( + input: MacroInput, + conn: sqlx::Connection, +) -> Result +where + ::TypeId: Display, +{ eprintln!("connection established"); - let prepared = conn.prepare(&hash, &input.sql) + let prepared = conn + .prepare(&input.sql) .await .map_err(|e| parse::Error::new(input.sql_span, e))?; if input.args.len() != prepared.param_types.len() { return Err(parse::Error::new( Span::call_site(), - format!("expected {} parameters, got {}", prepared.param_types.len(), input.args.len()) - ).into()); + format!( + "expected {} parameters, got {}", + prepared.param_types.len(), + input.args.len() + ), + ) + .into()); } - let param_types = prepared.param_types.iter().zip(&*input.args).map(|(type_, expr)| { - get_type_override(expr) - .or_else(|| postgres::map_param_type_oid(*type_)) - .ok_or_else(|| format!("unknown type OID: {}", type_).into()) - }) + let param_types = prepared + .param_types + .iter() + .zip(&*input.args) + .map(|(type_, expr)| { + get_type_override(expr) + .or_else(|| { + Some( + ::param_type_for_id(type_)? + .parse::() + .unwrap(), + ) + }) + .ok_or_else(|| format!("unknown type ID: {}", type_).into()) + }) .collect::>>()?; - let output_types = prepared.fields.iter().map(|field| { - postgres::map_output_type_oid(field.type_id) - }) + let output_types = prepared + .columns + .iter() + .map(|column| { + Ok( + ::return_type_for_id(&column.type_id) + .ok_or_else(|| format!("unknown type ID: {}", &column.type_id))? + .parse::() + .unwrap(), + ) + }) .collect::>>()?; let params = input.args.iter(); @@ -112,25 +174,23 @@ async fn process_sql(input: MacroInput) -> Result { let query = &input.sql; - Ok( - quote! {{ - use sqlx::TyConsExt as _; + Ok(quote! {{ + use sqlx::TyConsExt as _; - let params = (#(#params),*,); + let params = (#(#params),*,); - if false { - let _: (#(#param_types),*,) = (#(#params_ty_cons),*,); - } + if false { + let _: (#(#param_types),*,) = (#(#params_ty_cons),*,); + } - sqlx::CompiledSql::<_, (#(#output_types),*), sqlx::Postgres> { - query: #query, - params, - output: ::core::marker::PhantomData, - backend: ::core::marker::PhantomData, - } - }} - .into() - ) + sqlx::CompiledSql::<_, (#(#output_types),*), sqlx::Postgres> { + query: #query, + params, + output: ::core::marker::PhantomData, + backend: ::core::marker::PhantomData, + } + }} + .into()) } fn get_type_override(expr: &Expr) -> Option { diff --git a/sqlx-macros/src/postgres.rs b/sqlx-macros/src/postgres.rs deleted file mode 100644 index f9f79b95..00000000 --- a/sqlx-macros/src/postgres.rs +++ /dev/null @@ -1,45 +0,0 @@ -use proc_macro2::TokenStream; - -pub fn map_param_type_oid(oid: u32) -> Option { - Some(match oid { - 16 => "bool", - 1000 => "&[bool]", - 25 => "&str", - 1009 => "&[&str]", - 21 => "i16", - 1005 => "&[i16]", - 23 => "i32", - 1007 => "&[i32]", - 20 => "i64", - 1016 => "&[i64]", - 700 => "f32", - 1021 => "&[f32]", - 701 => "f64", - 1022 => "&[f64]", - 2950 => "sqlx::Uuid", - 2951 => "&[sqlx::Uuid]", - _ => return None - }.parse().unwrap()) -} - -pub fn map_output_type_oid(oid: u32) -> crate::Result { - Ok(match oid { - 16 => "bool", - 1000 => "Vec", - 25 => "String", - 1009 => "Vec", - 21 => "i16", - 1005 => "Vec", - 23 => "i32", - 1007 => "Vec", - 20 => "i64", - 1016 => "Vec", - 700 => "f32", - 1021 => "Vec", - 701 => "f64", - 1022 => "Vec", - 2950 => "sqlx::Uuid", - 2951 => "Vec", - _ => return Err(format!("unknown type ID: {}", oid).into()) - }.parse().unwrap()) -} diff --git a/src/backend.rs b/src/backend.rs index 08752d05..f0cc16d4 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -1,9 +1,9 @@ -use crate::{connection::RawConnection, query::QueryParameters, row::Row}; +use crate::{connection::RawConnection, query::QueryParameters, row::Row, types::HasTypeMetadata}; /// A database backend. /// /// This trait represents the concept of a backend (e.g. "MySQL" vs "SQLite"). -pub trait Backend: Sized { +pub trait Backend: HasTypeMetadata + Sized { /// The concrete `QueryParameters` implementation for this backend. type QueryParameters: QueryParameters; @@ -13,4 +13,12 @@ pub trait Backend: Sized { /// The concrete `Row` implementation for this backend. This type is returned /// from methods in the `RawConnection`. type Row: Row; + + /// The identifier for prepared statements; in Postgres this is a string + /// and in MariaDB/MySQL this is an integer. + type StatementIdent; + + /// The identifier for tables; in Postgres this is an `oid` while + /// in MariaDB/MySQL this is the qualified name of the table. + type TableIdent; } diff --git a/src/bin/cargo-sqlx.rs b/src/bin/cargo-sqlx.rs index c2540822..1d5ce833 100644 --- a/src/bin/cargo-sqlx.rs +++ b/src/bin/cargo-sqlx.rs @@ -1,5 +1,8 @@ -use std::{env, str}; -use std::io::{self, Write, Read}; +use std::{ + env, + io::{self, Read, Write}, + str, +}; use std::process::{Command, Stdio}; @@ -11,13 +14,17 @@ fn get_expanded_target() -> crate::Result> { let mut args = env::args_os().skip(2); - let cargo_args = args.by_ref().take_while(|arg| arg != "--").collect::>(); + let cargo_args = args + .by_ref() + .take_while(|arg| arg != "--") + .collect::>(); let rustc_args = args.collect::>(); let mut command = Command::new(cargo_path); - command.arg("rustc") + command + .arg("rustc") .args(cargo_args) .arg("--") .arg("-Z") @@ -61,7 +68,7 @@ fn find_next_sql_string(input: &str) -> Result> { let start = idx + STRING_START.len(); while let Some(end) = input[start..].find(STRING_END) { - if &input[start + end - 1 .. start + end] != "\\" { + if &input[start + end - 1..start + end] != "\\" { return Ok(Some(input[start..].split_at(end))); } } diff --git a/src/compiled.rs b/src/compiled.rs index a2836884..ed1d7184 100644 --- a/src/compiled.rs +++ b/src/compiled.rs @@ -1,9 +1,6 @@ +use crate::{query::IntoQueryParameters, Backend, Executor, FromSqlRow}; +use futures_core::{future::BoxFuture, stream::BoxStream, Stream}; use std::marker::PhantomData; -use crate::query::IntoQueryParameters; -use crate::{Backend, FromSqlRow, Executor}; -use futures_core::Stream; -use futures_core::stream::BoxStream; -use futures_core::future::BoxFuture; pub struct CompiledSql { #[doc(hidden)] @@ -12,30 +9,44 @@ pub struct CompiledSql { pub params: P, #[doc(hidden)] pub output: PhantomData, - pub backend: PhantomData + pub backend: PhantomData, } -impl CompiledSql where DB: Backend, P: IntoQueryParameters + Send, O: FromSqlRow + Send + Unpin { +impl CompiledSql +where + DB: Backend, + P: IntoQueryParameters + Send, + O: FromSqlRow + Send + Unpin, +{ #[inline] pub fn execute<'e, E: 'e>(self, executor: &'e E) -> BoxFuture<'e, crate::Result> - where - E: Executor, DB: 'e, P: 'e, O: 'e + where + E: Executor, + DB: 'e, + P: 'e, + O: 'e, { executor.execute(self.query, self.params) } #[inline] pub fn fetch<'e, E: 'e>(self, executor: &'e E) -> BoxStream<'e, crate::Result> - where - E: Executor, DB: 'e, P: 'e, O: 'e + where + E: Executor, + DB: 'e, + P: 'e, + O: 'e, { executor.fetch(self.query, self.params) } #[inline] pub fn fetch_all<'e, E: 'e>(self, executor: &'e E) -> BoxFuture<'e, crate::Result>> - where - E: Executor, DB: 'e, P: 'e, O: 'e + where + E: Executor, + DB: 'e, + P: 'e, + O: 'e, { executor.fetch_all(self.query, self.params) } @@ -45,16 +56,22 @@ impl CompiledSql where DB: Backend, P: IntoQueryParameters BoxFuture<'e, crate::Result>> - where - E: Executor, DB: 'e, P: 'e, O: 'e + where + E: Executor, + DB: 'e, + P: 'e, + O: 'e, { executor.fetch_optional(self.query, self.params) } #[inline] pub fn fetch_one<'e, E: 'e>(self, executor: &'e E) -> BoxFuture<'e, crate::Result> - where - E: Executor, DB: 'e, P: 'e, O: 'e + where + E: Executor, + DB: 'e, + P: 'e, + O: 'e, { executor.fetch_one(self.query, self.params) } diff --git a/src/connection.rs b/src/connection.rs index 18c43f78..a7fd68fb 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -3,6 +3,7 @@ use crate::{ error::Error, executor::Executor, pool::{Live, SharedPool}, + prepared::PreparedStatement, query::{IntoQueryParameters, QueryParameters}, row::FromSqlRow, }; @@ -19,7 +20,6 @@ use std::{ }, time::Instant, }; -use crate::prepared::PreparedStatement; /// A connection.bak to the database. /// @@ -73,10 +73,15 @@ pub trait RawConnection: Send { params: ::QueryParameters, ) -> crate::Result::Row>>; - async fn prepare(&mut self, name: &str, body: &str) -> crate::Result { - // TODO: implement for other backends - unimplemented!() - } + async fn prepare( + &mut self, + query: &str, + ) -> crate::Result<::StatementIdent>; + + async fn prepare_describe( + &mut self, + query: &str, + ) -> crate::Result>; } pub struct Connection(Arc>) @@ -128,9 +133,12 @@ where } /// Prepares a statement. - pub async fn prepare(&self, name: &str, body: &str) -> crate::Result { + /// + /// UNSTABLE: for use by sqlx-macros only + #[doc(hidden)] + pub async fn prepare(&self, body: &str) -> crate::Result> { let mut live = self.0.acquire().await; - let ret = live.raw.prepare(name, body).await?; + let ret = live.raw.prepare_describe(body).await?; self.0.release(live); Ok(ret) } diff --git a/src/error.rs b/src/error.rs index 96f66a61..4c668c59 100644 --- a/src/error.rs +++ b/src/error.rs @@ -54,7 +54,7 @@ impl Display for Error { match self { Error::Io(error) => write!(f, "{}", error), - Error::Database(error) => f.write_str(error.message()), + Error::Database(error) => Display::fmt(error, f), Error::NotFound => f.write_str("found no rows when we expected at least one"), @@ -85,8 +85,6 @@ where } /// An error that was returned by the database backend. -pub trait DatabaseError: Debug + Send + Sync { +pub trait DatabaseError: Display + Debug + Send + Sync { fn message(&self) -> &str; - - // TODO: Expose more error properties } diff --git a/src/lib.rs b/src/lib.rs index 857b6cf5..6e793ca7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -33,12 +33,11 @@ mod compiled; #[doc(inline)] pub use self::{ backend::Backend, - connection::Connection, compiled::CompiledSql, + connection::Connection, deserialize::FromSql, error::{Error, Result}, executor::Executor, - prepared::{PreparedStatement, Field}, pool::Pool, row::{FromSqlRow, Row}, serialize::ToSql, @@ -46,6 +45,9 @@ pub use self::{ types::HasSqlType, }; +#[doc(hidden)] +pub use types::HasTypeMetadata; + #[cfg(feature = "mariadb")] pub mod mariadb; diff --git a/src/mariadb/backend.rs b/src/mariadb/backend.rs index ae37b04d..c5295ad3 100644 --- a/src/mariadb/backend.rs +++ b/src/mariadb/backend.rs @@ -7,6 +7,8 @@ impl Backend for MariaDb { type QueryParameters = super::MariaDbQueryParameters; type RawConnection = super::MariaDbRawConnection; type Row = super::MariaDbRow; + type StatementIdent = u32; + type TableIdent = String; } impl_from_sql_row_tuples_for_backend!(MariaDb); diff --git a/src/mariadb/connection.rs b/src/mariadb/connection.rs index ce4921fe..233b8db0 100644 --- a/src/mariadb/connection.rs +++ b/src/mariadb/connection.rs @@ -11,11 +11,13 @@ use crate::{ }, MariaDb, MariaDbQueryParameters, MariaDbRow, }, - Backend, Error, Result, + prepared::{Column, PreparedStatement}, + Backend, Error, PreparedStatement, Result, }; use async_trait::async_trait; use byteorder::{ByteOrder, LittleEndian}; use futures_core::{future::BoxFuture, stream::BoxStream}; +use futures_util::stream::{self, StreamExt}; use std::{ future::Future, io, @@ -173,51 +175,33 @@ impl MariaDbRawConnection { }) } - // This should not be used by the user. It's mean for `RawConnection` impl - // This assumes the buffer has been set and all it needs is a flush - async fn exec_prepare(&mut self) -> Result { - self.stream.flush().await?; - - // COM_STMT_PREPARE returns COM_STMT_PREPARE_OK (0x00) or ERR (0xFF) - let mut packet = self.receive().await?; - let ok = match packet[0] { - 0xFF => { - let err = ErrPacket::decode(packet)?; - - // TODO: Bubble as Error::Database - // panic!("received db err = {:?}", err); - return Err( - io::Error::new(io::ErrorKind::InvalidInput, format!("{:?}", err)).into(), - ); - } - - _ => ComStmtPrepareOk::decode(packet)?, - }; - - // Skip decoding Column Definition packets for the result from a prepare statement - for _ in 0..ok.columns { - let _ = self.receive().await?; - } - - if ok.columns > 0 - && !self - .capabilities - .contains(Capabilities::CLIENT_DEPRECATE_EOF) + async fn check_eof(&mut self) -> Result<()> { + if !self + .capabilities + .contains(Capabilities::CLIENT_DEPRECATE_EOF) { - // TODO: Should we do something with the warning indicators here? - let _eof = EofPacket::decode(self.receive().await?)?; + let _ = EofPacket::decode(self.receive().await?)?; } - Ok(ok.statement_id) + Ok(()) } - async fn prepare<'c>(&'c mut self, statement: &'c str) -> Result { + async fn send_prepare<'c>(&'c mut self, statement: &'c str) -> Result { self.stream.flush().await?; self.start_sequence(); self.write(ComStmtPrepare { statement }); - self.exec_prepare().await + self.stream.flush().await?; + + // COM_STMT_PREPARE returns COM_STMT_PREPARE_OK (0x00) or ERR (0xFF) + let packet = self.receive().await?; + + if packet[0] == 0xFF { + return Err(ErrPacket::decode(packet)?.into()); + } + + ComStmtPrepareOk::decode(packet).map_err(Into::into) } async fn execute(&mut self, statement_id: u32, params: MariaDbQueryParameters) -> Result { @@ -323,11 +307,9 @@ impl RawConnection for MariaDbRawConnection { async fn execute(&mut self, query: &str, params: MariaDbQueryParameters) -> crate::Result { // Write prepare statement to buffer self.start_sequence(); - self.write(ComStmtPrepare { statement: query }); + let prepare_ok = self.send_prepare(query).await?; - let statement_id = self.exec_prepare().await?; - - let affected = self.execute(statement_id, params).await?; + let affected = self.execute(prepare_ok.statement_id, params).await?; Ok(affected) } @@ -347,6 +329,56 @@ impl RawConnection for MariaDbRawConnection { ) -> crate::Result::Row>> { unimplemented!(); } + + async fn prepare(&mut self, query: &str) -> crate::Result { + let prepare_ok = self.send_prepare(query).await?; + + for _ in 0..prepare_ok.params { + let _ = self.receive().await?; + } + + self.check_eof().await?; + + for _ in 0..prepare_ok.columns { + let _ = self.receive().await?; + } + + self.check_eof().await?; + + Ok(prepare_ok.statement_id) + } + + async fn prepare_describe(&mut self, query: &str) -> crate::Result> { + let prepare_ok = self.send_prepare(query).await?; + + let mut param_types = Vec::with_capacity(prepare_ok.params as usize); + + for _ in 0..prepare_ok.params { + let param = ColumnDefinitionPacket::decode(self.receive().await?)?; + param_types.push(param.field_type.0); + } + + self.check_eof().await?; + + let mut columns = Vec::with_capacity(prepare_ok.columns as usize); + + for _ in 0..prepare_ok.columns { + let column = ColumnDefinitionPacket::decode(self.receive().await?)?; + columns.push(Column { + name: column.column_alias.or(column.column), + table_id: column.table_alias.or(column.table), + type_id: column.field_type.0, + }) + } + + self.check_eof().await?; + + Ok(PreparedStatement { + identifier: prepare_ok.statement_id, + param_types, + columns, + }) + } } #[cfg(test)] diff --git a/src/mariadb/error.rs b/src/mariadb/error.rs new file mode 100644 index 00000000..0701c669 --- /dev/null +++ b/src/mariadb/error.rs @@ -0,0 +1,21 @@ +use crate::{error::DatabaseError, mariadb::protocol::ErrorCode}; + +use std::fmt; + +#[derive(Debug)] +pub struct Error { + pub code: ErrorCode, + pub message: Box, +} + +impl DatabaseError for Error { + fn message(&self) -> &str { + &self.message + } +} + +impl fmt::Display for ErrorCode { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "MariaDB returned an error: {}",) + } +} diff --git a/src/mariadb/mod.rs b/src/mariadb/mod.rs index 602b179b..12345aa8 100644 --- a/src/mariadb/mod.rs +++ b/src/mariadb/mod.rs @@ -1,5 +1,6 @@ mod backend; mod connection; +mod error; mod establish; mod io; mod protocol; diff --git a/src/mariadb/protocol/error_code.rs b/src/mariadb/protocol/error_code.rs index f6096a6e..fd8b4e09 100644 --- a/src/mariadb/protocol/error_code.rs +++ b/src/mariadb/protocol/error_code.rs @@ -1,10 +1,40 @@ +use std::fmt; + #[derive(Default, Debug)] pub struct ErrorCode(pub(crate) u16); -// TODO: It would be nice to figure out a clean way to go from 1152 to "ER_ABORTING_CONNECTION (1152)" in Debug. +use crate::error::DatabaseError; +use bitflags::_core::fmt::{Error, Formatter}; + +macro_rules! error_code_impl { + ($(const $name:ident: ErrorCode = ErrorCode($code:expr));*;) => { + impl ErrorCode { + $(const $name: ErrorCode = ErrorCode($code);)* + + pub fn code_name(&self) -> &'static str { + match self.0 { + $($code => $name,)* + _ => "" + } + } + } + } +} + +impl fmt::Debug for ErrorCode { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "ErrorCode({} [()])",) + } +} + +impl fmt::Display for ErrorCode { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{} ({})", self.code_name(), self.0) + } +} // Values from https://mariadb.com/kb/en/library/mariadb-error-codes/ -impl ErrorCode { +error_code_impl! { const ER_ABORTING_CONNECTION: ErrorCode = ErrorCode(1152); const ER_ACCESS_DENIED_CHANGE_USER_ERROR: ErrorCode = ErrorCode(1873); const ER_ACCESS_DENIED_ERROR: ErrorCode = ErrorCode(1045); diff --git a/src/mariadb/protocol/response/err.rs b/src/mariadb/protocol/response/err.rs index 8e5144ce..9194638c 100644 --- a/src/mariadb/protocol/response/err.rs +++ b/src/mariadb/protocol/response/err.rs @@ -1,6 +1,6 @@ use crate::{ io::Buf, - mariadb::{io::BufExt, protocol::ErrorCode}, + mariadb::{error::Error, io::BufExt, protocol::ErrorCode}, }; use byteorder::LittleEndian; use std::io; @@ -66,6 +66,15 @@ impl ErrPacket { }) } } + + pub fn expect_error(self) -> crate::Result { + match self { + ErrPacket::Progress { .. } => { + Err(format!("expected ErrPacket::Err, got {:?}", self).into()) + } + ErrPacket::Error { code, message, .. } => Err(Error { code, message }.into()), + } + } } #[cfg(test)] diff --git a/src/mariadb/types/mod.rs b/src/mariadb/types/mod.rs index 5e4f1892..66165979 100644 --- a/src/mariadb/types/mod.rs +++ b/src/mariadb/types/mod.rs @@ -1,5 +1,8 @@ use super::protocol::{FieldType, ParameterFlag}; -use crate::{mariadb::MariaDb, types::TypeMetadata}; +use crate::{ + mariadb::MariaDb, + types::{HasTypeMetadata, TypeMetadata}, +}; pub mod boolean; pub mod character; @@ -11,6 +14,43 @@ pub struct MariaDbTypeMetadata { pub param_flag: ParameterFlag, } -impl TypeMetadata for MariaDb { +impl HasTypeMetadata for MariaDb { type TypeMetadata = MariaDbTypeMetadata; + type TypeId = u8; + + fn param_type_for_id(id: &Self::TypeId) -> Option<&'static str> { + Some(match FieldType(*id) { + FieldType::MYSQL_TYPE_TINY => "i8", + FieldType::MYSQL_TYPE_SHORT => "i16", + FieldType::MYSQL_TYPE_LONG => "i32", + FieldType::MYSQL_TYPE_LONGLONG => "i64", + FieldType::MYSQL_TYPE_VAR_STRING => "&str", + FieldType::MYSQL_TYPE_FLOAT => "f32", + FieldType::MYSQL_TYPE_DOUBLE => "f64", + FieldType::MYSQL_TYPE_BLOB => "&[u8]", + _ => return None + }) + } + + fn return_type_for_id(id: &Self::TypeId) -> Option<&'static str> { + Some(match FieldType(*id) { + FieldType::MYSQL_TYPE_TINY => "i8", + FieldType::MYSQL_TYPE_SHORT => "i16", + FieldType::MYSQL_TYPE_LONG => "i32", + FieldType::MYSQL_TYPE_LONGLONG => "i64", + FieldType::MYSQL_TYPE_VAR_STRING => "String", + FieldType::MYSQL_TYPE_FLOAT => "f32", + FieldType::MYSQL_TYPE_DOUBLE => "f64", + FieldType::MYSQL_TYPE_BLOB => "Vec", + _ => return None + }) + } +} + +impl TypeMetadata for MariaDbTypeMetadata { + type TypeId = u8; + + fn type_id(&self) -> &Self::TypeId { + &self.field_type.0 + } } diff --git a/src/postgres/backend.rs b/src/postgres/backend.rs index 5a24ee17..81cbe85c 100644 --- a/src/postgres/backend.rs +++ b/src/postgres/backend.rs @@ -7,6 +7,8 @@ impl Backend for Postgres { type QueryParameters = super::PostgresQueryParameters; type RawConnection = super::PostgresRawConnection; type Row = super::PostgresRow; + type StatementIdent = String; + type TableIdent = u32; } impl_from_sql_row_tuples_for_backend!(Postgres); diff --git a/src/postgres/connection.rs b/src/postgres/connection.rs index ef6725b2..98c330c8 100644 --- a/src/postgres/connection.rs +++ b/src/postgres/connection.rs @@ -1,10 +1,19 @@ use super::{Postgres, PostgresQueryParameters, PostgresRawConnection, PostgresRow}; -use crate::{connection::RawConnection, postgres::raw::Step, url::Url, Error}; -use crate::query::QueryParameters; +use crate::{ + connection::RawConnection, + postgres::{error::ProtocolError, raw::Step}, + prepared::{Column, PreparedStatement}, + query::QueryParameters, + url::Url, + Error, +}; use async_trait::async_trait; use futures_core::stream::BoxStream; -use crate::prepared::{PreparedStatement, Field}; -use crate::postgres::error::ProtocolError; + +use std::sync::atomic::{AtomicU64, Ordering}; + +use crate::postgres::{protocol::Message, PostgresDatabaseError}; +use std::hash::Hasher; #[async_trait] impl RawConnection for PostgresRawConnection { @@ -96,45 +105,80 @@ impl RawConnection for PostgresRawConnection { Ok(row) } - async fn prepare(&mut self, name: &str, body: &str) -> crate::Result { - self.parse(name, body, &PostgresQueryParameters::new()); - self.describe(name); + async fn prepare(&mut self, body: &str) -> crate::Result { + let name = gen_statement_name(body); + self.parse(&name, body, &PostgresQueryParameters::new()); + + match self.receive().await? { + Some(Message::Response(response)) => Err(PostgresDatabaseError(response).into()), + Some(Message::ParseComplete) => Ok(name), + Some(message) => { + Err(ProtocolError(format!("unexpected message: {:?}", message)).into()) + } + None => Err(ProtocolError("expected ParseComplete or ErrorResponse").into()), + } + } + + async fn prepare_describe(&mut self, body: &str) -> crate::Result> { + let name = gen_statement_name(body); + self.parse(&name, body, &PostgresQueryParameters::new()); + self.describe(&name); self.sync().await?; - let param_desc= loop { - let step = self.step().await? + let param_desc = loop { + let step = self + .step() + .await? .ok_or(ProtocolError("did not receive ParameterDescription")); - if let Step::ParamDesc(desc) = dbg!(step)? - { - break desc; - } + if let Step::ParamDesc(desc) = step? { + break desc; + } }; let row_desc = loop { - let step = self.step().await? + let step = self + .step() + .await? .ok_or(ProtocolError("did not receive RowDescription")); - if let Step::RowDesc(desc) = dbg!(step)? - { + if let Step::RowDesc(desc) = step? { break desc; } }; Ok(PreparedStatement { - name: name.into(), - param_types: param_desc.ids, - fields: row_desc.fields.into_vec().into_iter() - .map(|field| Field { - name: field.name, - table_id: field.table_id, - type_id: field.type_id + identifier: name.into(), + param_types: param_desc.ids.into_vec(), + columns: row_desc + .fields + .into_vec() + .into_iter() + .map(|field| Column { + name: Some(field.name), + table_id: Some(field.table_id), + type_id: field.type_id, }) .collect(), }) } } +static STATEMENT_COUNT: AtomicU64 = AtomicU64::new(0); + +fn gen_statement_name(query: &str) -> String { + // hasher with no external dependencies + use std::collections::hash_map::DefaultHasher; + + let mut hasher = DefaultHasher::new(); + // including a global counter should help prevent collision + // with queries with the same content + hasher.write_u64(STATEMENT_COUNT.fetch_add(1, Ordering::SeqCst)); + hasher.write(query.as_bytes()); + + format!("sqlx_stmt_{:x}", hasher.finish()) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/postgres/error.rs b/src/postgres/error.rs index fbd00f2a..d06241a7 100644 --- a/src/postgres/error.rs +++ b/src/postgres/error.rs @@ -1,7 +1,10 @@ use super::protocol::Response; use crate::error::DatabaseError; -use std::borrow::Cow; -use std::fmt::Debug; +use bitflags::_core::fmt::{Error, Formatter}; +use std::{ + borrow::Cow, + fmt::{self, Debug, Display}, +}; #[derive(Debug)] pub struct PostgresDatabaseError(pub(super) Box); @@ -15,8 +18,20 @@ impl DatabaseError for PostgresDatabaseError { } } +impl Display for PostgresDatabaseError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.pad(self.message()) + } +} + impl + Debug + Send + Sync> DatabaseError for ProtocolError { fn message(&self) -> &str { self.0.as_ref() } } + +impl> Display for ProtocolError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.pad(self.0.as_ref()) + } +} diff --git a/src/postgres/protocol/message.rs b/src/postgres/protocol/message.rs index bc6c0a58..4f84164d 100644 --- a/src/postgres/protocol/message.rs +++ b/src/postgres/protocol/message.rs @@ -21,5 +21,5 @@ pub enum Message { NoData, PortalSuspended, ParameterDescription(Box), - RowDescription(Box) + RowDescription(Box), } diff --git a/src/postgres/protocol/mod.rs b/src/postgres/protocol/mod.rs index 9e81a143..d9be891c 100644 --- a/src/postgres/protocol/mod.rs +++ b/src/postgres/protocol/mod.rs @@ -31,10 +31,22 @@ mod terminate; // TODO: mod ssl_request; pub use self::{ - bind::Bind, cancel_request::CancelRequest, close::Close, copy_data::CopyData, - copy_done::CopyDone, copy_fail::CopyFail, describe::Describe, describe::DescribeKind, encode::Encode, execute::Execute, - flush::Flush, parse::Parse, password_message::PasswordMessage, query::Query, - startup_message::StartupMessage, sync::Sync, terminate::Terminate, + bind::Bind, + cancel_request::CancelRequest, + close::Close, + copy_data::CopyData, + copy_done::CopyDone, + copy_fail::CopyFail, + describe::{Describe, DescribeKind}, + encode::Encode, + execute::Execute, + flush::Flush, + parse::Parse, + password_message::PasswordMessage, + query::Query, + startup_message::StartupMessage, + sync::Sync, + terminate::Terminate, }; mod authentication; @@ -54,10 +66,17 @@ mod row_description; mod message; pub use self::{ - authentication::Authentication, backend_key_data::BackendKeyData, - command_complete::CommandComplete, data_row::DataRow, decode::Decode, message::Message, - notification_response::NotificationResponse, parameter_description::ParameterDescription, - parameter_status::ParameterStatus, ready_for_query::ReadyForQuery, response::Response, + authentication::Authentication, + backend_key_data::BackendKeyData, + command_complete::CommandComplete, + data_row::DataRow, + decode::Decode, + message::Message, + notification_response::NotificationResponse, + parameter_description::ParameterDescription, + parameter_status::ParameterStatus, + ready_for_query::ReadyForQuery, + response::Response, row_description::{RowDescription, RowField}, }; diff --git a/src/postgres/protocol/row_description.rs b/src/postgres/protocol/row_description.rs index 9ad529be..4c38293c 100644 --- a/src/postgres/protocol/row_description.rs +++ b/src/postgres/protocol/row_description.rs @@ -1,8 +1,7 @@ use super::Decode; use crate::io::Buf; use byteorder::NetworkEndian; -use std::io; -use std::io::BufRead; +use std::{io, io::BufRead}; #[derive(Debug)] pub struct RowDescription { @@ -17,7 +16,7 @@ pub struct RowField { pub type_id: u32, pub type_size: i16, pub type_mod: i32, - pub format_code: i16 + pub format_code: i16, } impl Decode for RowDescription { @@ -26,7 +25,7 @@ impl Decode for RowDescription { let mut fields = Vec::with_capacity(cnt); for _ in 0..cnt { - fields.push(dbg!(RowField { + fields.push(RowField { name: super::read_string(&mut buf)?, table_id: buf.get_u32::()?, attr_num: buf.get_i16::()?, @@ -34,7 +33,7 @@ impl Decode for RowDescription { type_size: buf.get_i16::()?, type_mod: buf.get_i32::()?, format_code: buf.get_i16::()?, - })); + }); } Ok(Self { diff --git a/src/postgres/raw.rs b/src/postgres/raw.rs index 110ad33a..bfe91d1a 100644 --- a/src/postgres/raw.rs +++ b/src/postgres/raw.rs @@ -151,8 +151,9 @@ impl PostgresRawConnection { pub(super) fn describe(&mut self, statement: &str) { protocol::Describe { kind: protocol::DescribeKind::PreparedStatement, - name: statement - }.encode(self.stream.buffer_mut()) + name: statement, + } + .encode(self.stream.buffer_mut()) } pub(super) fn bind(&mut self, portal: &str, statement: &str, params: &PostgresQueryParameters) { @@ -198,15 +199,15 @@ impl PostgresRawConnection { Message::ReadyForQuery(_) => { return Ok(None); - }, + } Message::ParameterDescription(desc) => { return Ok(Some(Step::ParamDesc(desc))); - }, + } Message::RowDescription(desc) => { return Ok(Some(Step::RowDesc(desc))); - }, + } message => { return Err(io::Error::new( @@ -260,9 +261,7 @@ impl PostgresRawConnection { b't' => Message::ParameterDescription(Box::new( protocol::ParameterDescription::decode(body)?, )), - b'T' => Message::RowDescription(Box::new( - protocol::RowDescription::decode(body)? - )), + b'T' => Message::RowDescription(Box::new(protocol::RowDescription::decode(body)?)), id => { return Err(io::Error::new( diff --git a/src/postgres/types/binary.rs b/src/postgres/types/binary.rs new file mode 100644 index 00000000..4e3751a0 --- /dev/null +++ b/src/postgres/types/binary.rs @@ -0,0 +1,41 @@ +use crate::{ + postgres::types::{PostgresTypeFormat, PostgresTypeMetadata}, + serialize::IsNull, + types::TypeMetadata, + FromSql, HasSqlType, Postgres, ToSql, +}; + +impl HasSqlType<[u8]> for Postgres { + fn metadata() -> Self::TypeMetadata { + PostgresTypeMetadata { + format: PostgresTypeFormat::Binary, + oid: 17, + array_oid: 1001, + } + } +} + +impl HasSqlType> for Postgres { + fn metadata() -> Self::TypeMetadata { + >::metadata() + } +} + +impl ToSql for [u8] { + fn to_sql(&self, buf: &mut Vec) -> IsNull { + buf.extend_from_slice(self); + IsNull::No + } +} + +impl ToSql for Vec { + fn to_sql(&self, buf: &mut Vec) -> IsNull { + <[u8] as ToSql>::to_sql(self, buf) + } +} + +impl FromSql for Vec { + fn from_sql(raw: Option<&[u8]>) -> Self { + raw.unwrap().into() + } +} diff --git a/src/postgres/types/mod.rs b/src/postgres/types/mod.rs index 0f92985a..38a49541 100644 --- a/src/postgres/types/mod.rs +++ b/src/postgres/types/mod.rs @@ -28,9 +28,12 @@ //! | `Uuid` (`uuid` feature) | UUID | use super::Postgres; -use crate::types::TypeMetadata; -use crate::HasSqlType; +use crate::{ + types::{HasTypeMetadata, TypeMetadata}, + HasSqlType, +}; +mod binary; mod boolean; mod character; mod numeric; @@ -54,6 +57,59 @@ pub struct PostgresTypeMetadata { pub array_oid: u32, } -impl TypeMetadata for Postgres { +impl HasTypeMetadata for Postgres { + type TypeId = u32; type TypeMetadata = PostgresTypeMetadata; + + fn param_type_for_id(id: &Self::TypeId) -> Option<&'static str> { + Some(match id { + 16 => "bool", + 1000 => "&[bool]", + 25 => "&str", + 1009 => "&[&str]", + 21 => "i16", + 1005 => "&[i16]", + 23 => "i32", + 1007 => "&[i32]", + 20 => "i64", + 1016 => "&[i64]", + 700 => "f32", + 1021 => "&[f32]", + 701 => "f64", + 1022 => "&[f64]", + 2950 => "sqlx::Uuid", + 2951 => "&[sqlx::Uuid]", + _ => return None, + }) + } + + fn return_type_for_id(id: &Self::TypeId) -> Option<&'static str> { + Some(match id { + 16 => "bool", + 1000 => "Vec", + 25 => "String", + 1009 => "Vec", + 21 => "i16", + 1005 => "Vec", + 23 => "i32", + 1007 => "Vec", + 20 => "i64", + 1016 => "Vec", + 700 => "f32", + 1021 => "Vec", + 701 => "f64", + 1022 => "Vec", + 2950 => "sqlx::Uuid", + 2951 => "Vec", + _ => return None, + }) + } +} + +impl TypeMetadata for PostgresTypeMetadata { + type TypeId = u32; + + fn type_id(&self) -> &u32 { + &self.oid + } } diff --git a/src/prepared.rs b/src/prepared.rs index 9292d4d4..e5e71476 100644 --- a/src/prepared.rs +++ b/src/prepared.rs @@ -1,13 +1,25 @@ -#[derive(Debug)] -pub struct PreparedStatement { - pub name: String, - pub param_types: Box<[u32]>, - pub fields: Vec, +use crate::{query::QueryParameters, Backend, Error, Executor, FromSqlRow, HasSqlType, ToSql}; + +use futures_core::{future::BoxFuture, stream::BoxStream}; +use std::marker::PhantomData; + +use crate::types::{HasTypeMetadata, TypeMetadata}; + +use std::fmt::{self, Debug}; + +/// A prepared statement. +pub struct PreparedStatement { + /// + pub identifier: ::StatementIdent, + /// The expected type IDs of bind parameters. + pub param_types: Vec<::TypeId>, + /// + pub columns: Vec>, } -#[derive(Debug)] -pub struct Field { - pub name: String, - pub table_id: u32, - pub type_id: u32, +pub struct Column { + pub name: Option, + pub table_id: Option<::TableIdent>, + /// The type ID of this result column. + pub type_id: ::TypeId, } diff --git a/src/serialize.rs b/src/serialize.rs index 905f2dfb..95f73641 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -46,9 +46,9 @@ where } impl ToSql for &'_ T - where - DB: Backend + HasSqlType, - T: ToSql, +where + DB: Backend + HasSqlType, + T: ToSql, { #[inline] fn to_sql(&self, buf: &mut Vec) -> IsNull { diff --git a/src/types.rs b/src/types.rs index 3995467b..d09b5e75 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,18 +1,39 @@ /// Information about how a backend stores metadata about /// given SQL types. -pub trait TypeMetadata { +pub trait HasTypeMetadata { /// The actual type used to represent metadata. - type TypeMetadata; + type TypeMetadata: TypeMetadata; + + /// The Rust type of the type ID for the backend. + type TypeId: Eq; + + /// UNSTABLE: for internal use only + #[doc(hidden)] + fn param_type_for_id(id: &Self::TypeId) -> Option<&'static str>; + + /// UNSTABLE: for internal use only + #[doc(hidden)] + fn return_type_for_id(id: &Self::TypeId) -> Option<&'static str>; +} + +pub trait TypeMetadata { + type TypeId: Eq; + + fn type_id(&self) -> &Self::TypeId; + fn type_id_eq(&self, id: &Self::TypeId) -> bool { + self.type_id() == id + } } /// Indicates that a SQL type exists for a backend and defines /// useful metadata for the backend. -pub trait HasSqlType: TypeMetadata { +pub trait HasSqlType: HasTypeMetadata { fn metadata() -> Self::TypeMetadata; } impl HasSqlType<&'_ A> for DB - where DB: HasSqlType +where + DB: HasSqlType, { fn metadata() -> Self::TypeMetadata { >::metadata() @@ -20,7 +41,8 @@ impl HasSqlType<&'_ A> for DB } impl HasSqlType> for DB - where DB: HasSqlType +where + DB: HasSqlType, { fn metadata() -> Self::TypeMetadata { >::metadata() diff --git a/tests/sql-macro-test.rs b/tests/sql-macro-test.rs index 3b599963..5bcb2ee7 100644 --- a/tests/sql-macro-test.rs +++ b/tests/sql-macro-test.rs @@ -2,7 +2,9 @@ #[tokio::test] async fn test_sqlx_macro() -> sqlx::Result<()> { - let conn = sqlx::Connection::::establish("postgres://postgres@127.0.0.1/sqlx_test").await?; + let conn = + sqlx::Connection::::establish("postgres://postgres@127.0.0.1/sqlx_test") + .await?; let uuid: sqlx::Uuid = "256ba9c8-0048-11ea-b0f0-8f04859d047e".parse().unwrap(); let accounts = sqlx_macros::sql!("SELECT * from accounts where id = $1", 5i64) .fetch_one(&conn)