From c040c97cb36949c967d94ff63c3357f4d8a009e6 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Tue, 16 Jul 2019 01:15:12 -0700 Subject: [PATCH] Initial experiment with the low-level API and PREPARE + EXECUTE queries --- Cargo.toml | 5 + sqlx-postgres-protocol/src/authentication.rs | 4 +- sqlx-postgres-protocol/src/bind.rs | 80 +++++ sqlx-postgres-protocol/src/data_row.rs | 2 - sqlx-postgres-protocol/src/execute.rs | 31 ++ sqlx-postgres-protocol/src/lib.rs | 8 + sqlx-postgres-protocol/src/message.rs | 4 + sqlx-postgres-protocol/src/parse.rs | 42 +++ sqlx-postgres-protocol/src/sync.rs | 13 + sqlx-postgres/src/connection/execute.rs | 291 +++++++++++++++++++ sqlx-postgres/src/connection/mod.rs | 43 +-- sqlx-postgres/src/connection/query.rs | 34 --- src/main.rs | 14 +- 13 files changed, 501 insertions(+), 70 deletions(-) create mode 100644 sqlx-postgres-protocol/src/bind.rs create mode 100644 sqlx-postgres-protocol/src/execute.rs create mode 100644 sqlx-postgres-protocol/src/parse.rs create mode 100644 sqlx-postgres-protocol/src/sync.rs create mode 100644 sqlx-postgres/src/connection/execute.rs delete mode 100644 sqlx-postgres/src/connection/query.rs diff --git a/Cargo.toml b/Cargo.toml index 9ee346e9..57f45627 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,8 @@ bytes = "0.4.12" lto = true codegen-units = 1 incremental = false + +[profile.release] +lto = true +codegen-units = 1 +incremental = false diff --git a/sqlx-postgres-protocol/src/authentication.rs b/sqlx-postgres-protocol/src/authentication.rs index 73e6189d..41cfaec6 100644 --- a/sqlx-postgres-protocol/src/authentication.rs +++ b/sqlx-postgres-protocol/src/authentication.rs @@ -45,13 +45,13 @@ impl Decode for Authentication { 0 => Authentication::Ok, 2 => Authentication::KerberosV5, 3 => Authentication::CleartextPassword, - + 5 => { let mut salt = [0_u8; 4]; salt.copy_from_slice(&src[1..5]); Authentication::Md5Password { salt } - }, + } 6 => Authentication::ScmCredential, 7 => Authentication::Gss, diff --git a/sqlx-postgres-protocol/src/bind.rs b/sqlx-postgres-protocol/src/bind.rs new file mode 100644 index 00000000..0ad1937c --- /dev/null +++ b/sqlx-postgres-protocol/src/bind.rs @@ -0,0 +1,80 @@ +use crate::Encode; +use byteorder::{BigEndian, ByteOrder}; +use std::io; + +pub struct Bind<'a> { + // The name of the destination portal (an empty string selects the unnamed portal). + portal: &'a str, + + // The name of the source prepared statement (an empty string selects the unnamed prepared statement). + statement: &'a str, + + // The parameter format codes. + formats: &'a [i16], + + // The values of the parameters. + // Arranged as: [len][size_0][value_0][size_1][value_1] etc... + buffer: &'a [u8], + + // The result-column format codes. + result_formats: &'a [i16], +} + +impl<'a> Bind<'a> { + pub fn new( + portal: &'a str, + statement: &'a str, + formats: &'a [i16], + buffer: &'a [u8], + result_formats: &'a [i16], + ) -> Self { + Self { + portal, + statement, + formats, + buffer, + result_formats, + } + } +} + +impl<'a> Encode for Bind<'a> { + fn encode(&self, buf: &mut Vec) -> io::Result<()> { + buf.push(b'B'); + + let pos = buf.len(); + buf.extend_from_slice(&[0, 0, 0, 0]); + + // portal + buf.extend_from_slice(self.portal.as_bytes()); + buf.push(b'\0'); + + // statement + buf.extend_from_slice(self.statement.as_bytes()); + buf.push(b'\0'); + + // formats.len + buf.extend_from_slice(&(self.formats.len() as i16).to_be_bytes()); + + // formats + for format in self.formats { + buf.extend_from_slice(&format.to_be_bytes()); + } + + // values + buf.extend_from_slice(&self.buffer); + + // result_formats.len + buf.extend_from_slice(&(self.result_formats.len() as i16).to_be_bytes()); + + // result_formats + for format in self.result_formats { + buf.extend_from_slice(&format.to_be_bytes()); + } + + let len = buf.len() - pos; + BigEndian::write_u32(&mut buf[pos..], len as u32); + + Ok(()) + } +} diff --git a/sqlx-postgres-protocol/src/data_row.rs b/sqlx-postgres-protocol/src/data_row.rs index 0458afc6..0648762a 100644 --- a/sqlx-postgres-protocol/src/data_row.rs +++ b/sqlx-postgres-protocol/src/data_row.rs @@ -92,8 +92,6 @@ mod tests { let src = Bytes::from_static(DATA_ROW); let message = DataRow::decode(src)?; - println!("{:?}", message); - assert_eq!(message.len(), 3); assert_eq!(message.get(0), Some(&b"1"[..])); diff --git a/sqlx-postgres-protocol/src/execute.rs b/sqlx-postgres-protocol/src/execute.rs new file mode 100644 index 00000000..d17254ee --- /dev/null +++ b/sqlx-postgres-protocol/src/execute.rs @@ -0,0 +1,31 @@ +use crate::Encode; +use std::io; + +pub struct Execute<'a> { + portal: &'a str, + limit: i32, +} + +impl<'a> Execute<'a> { + pub fn new(portal: &'a str, limit: i32) -> Self { + Self { portal, limit } + } +} + +impl<'a> Encode for Execute<'a> { + fn encode(&self, buf: &mut Vec) -> io::Result<()> { + buf.push(b'E'); + + let len = 4 + self.portal.len() + 1 + 4; + buf.extend_from_slice(&(len as i32).to_be_bytes()); + + // portal + buf.extend_from_slice(self.portal.as_bytes()); + buf.push(b'\0'); + + // limit + buf.extend_from_slice(&self.limit.to_be_bytes()); + + Ok(()) + } +} diff --git a/sqlx-postgres-protocol/src/lib.rs b/sqlx-postgres-protocol/src/lib.rs index 20ee7238..f6111756 100644 --- a/sqlx-postgres-protocol/src/lib.rs +++ b/sqlx-postgres-protocol/src/lib.rs @@ -2,34 +2,42 @@ mod authentication; mod backend_key_data; +mod bind; mod command_complete; mod data_row; mod decode; mod encode; +mod execute; mod message; mod parameter_status; +mod parse; mod password_message; mod query; mod ready_for_query; mod response; mod row_description; mod startup_message; +mod sync; mod terminate; pub use self::{ authentication::Authentication, backend_key_data::BackendKeyData, + bind::Bind, command_complete::CommandComplete, data_row::DataRow, decode::Decode, encode::Encode, + execute::Execute, message::Message, parameter_status::ParameterStatus, + parse::Parse, password_message::PasswordMessage, query::Query, ready_for_query::{ReadyForQuery, TransactionStatus}, response::{Response, Severity}, row_description::{FieldDescription, FieldDescriptions, RowDescription}, startup_message::StartupMessage, + sync::Sync, terminate::Terminate, }; diff --git a/sqlx-postgres-protocol/src/message.rs b/sqlx-postgres-protocol/src/message.rs index f34dfac5..300fd64f 100644 --- a/sqlx-postgres-protocol/src/message.rs +++ b/sqlx-postgres-protocol/src/message.rs @@ -16,6 +16,8 @@ pub enum Message { RowDescription(RowDescription), DataRow(DataRow), Response(Box), + ParseComplete, + BindComplete, } impl Message { @@ -55,6 +57,8 @@ impl Message { b'T' => Message::RowDescription(RowDescription::decode(src)?), b'D' => Message::DataRow(DataRow::decode(src)?), b'C' => Message::CommandComplete(CommandComplete::decode(src)?), + b'1' => Message::ParseComplete, + b'2' => Message::BindComplete, _ => unimplemented!("decode not implemented for token: {}", token as char), })) diff --git a/sqlx-postgres-protocol/src/parse.rs b/sqlx-postgres-protocol/src/parse.rs new file mode 100644 index 00000000..ebf65ca8 --- /dev/null +++ b/sqlx-postgres-protocol/src/parse.rs @@ -0,0 +1,42 @@ +use crate::Encode; +use std::io; + +pub struct Parse<'a> { + portal: &'a str, + query: &'a str, + param_types: &'a [i32], +} + +impl<'a> Parse<'a> { + pub fn new(portal: &'a str, query: &'a str, param_types: &'a [i32]) -> Self { + Self { + portal, + query, + param_types, + } + } +} + +impl<'a> Encode for Parse<'a> { + fn encode(&self, buf: &mut Vec) -> io::Result<()> { + buf.push(b'P'); + + let len = 4 + self.portal.len() + 1 + self.query.len() + 1 + 2 + self.param_types.len() * 4; + + buf.extend_from_slice(&(len as i32).to_be_bytes()); + + buf.extend_from_slice(self.portal.as_bytes()); + buf.push(b'\0'); + + buf.extend_from_slice(self.query.as_bytes()); + buf.push(b'\0'); + + buf.extend_from_slice(&(self.param_types.len() as i16).to_be_bytes()); + + for param_type in self.param_types { + buf.extend_from_slice(¶m_type.to_be_bytes()); + } + + Ok(()) + } +} diff --git a/sqlx-postgres-protocol/src/sync.rs b/sqlx-postgres-protocol/src/sync.rs new file mode 100644 index 00000000..ebff82ae --- /dev/null +++ b/sqlx-postgres-protocol/src/sync.rs @@ -0,0 +1,13 @@ +use crate::Encode; +use std::io; + +pub struct Sync; + +impl Encode for Sync { + fn encode(&self, buf: &mut Vec) -> io::Result<()> { + buf.push(b'S'); + buf.extend_from_slice(&4_i32.to_be_bytes()); + + Ok(()) + } +} diff --git a/sqlx-postgres/src/connection/execute.rs b/sqlx-postgres/src/connection/execute.rs new file mode 100644 index 00000000..4483e93f --- /dev/null +++ b/sqlx-postgres/src/connection/execute.rs @@ -0,0 +1,291 @@ +use super::Connection; +use futures::{io::AsyncWrite, ready, Stream}; +use sqlx_postgres_protocol::{self as proto, Encode, Parse}; +use std::{ + future::Future, + io, + pin::Pin, + sync::atomic::Ordering, + task::{Context, Poll}, +}; + +// NOTE: This is a rough draft of the implementation + +#[inline] +pub fn execute<'a>(connection: &'a mut Connection, query: &'a str) -> Execute<'a> { + Execute { + connection, + query, + state: ExecuteState::Parse, + rows: 0, + } +} + +pub struct Execute<'a> { + connection: &'a mut Connection, + query: &'a str, + state: ExecuteState, + rows: u64, +} + +#[derive(Debug)] +enum ExecuteState { + Parse, + Bind, + Execute, + Sync, + SendingParse, + SendingBind, + SendingExecute, + SendingSync, + Flush, + WaitForComplete, +} + +impl<'a> Execute<'a> { + #[inline] + pub fn bind(self, value: &'a [u8]) -> Bind<'a, &'a [u8]> { + Bind { ex: self, value } + } +} + +fn poll_write_all( + mut writer: W, + buf: &mut Vec, + cx: &mut Context, +) -> Poll> { + // Derived from https://rust-lang-nursery.github.io/futures-api-docs/0.3.0-alpha.16/src/futures_util/io/write_all.rs.html#26 + while !buf.is_empty() { + let n = ready!(Pin::new(&mut writer).poll_write(cx, &*buf))?; + + buf.truncate(buf.len() - n); + + if n == 0 { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + } + + Poll::Ready(Ok(())) +} + +fn poll_execute( + cx: &mut Context, + conn: &mut Connection, + state: &mut ExecuteState, + query: &str, + values: &T, + out: &mut u64, +) -> Poll> { + loop { + *state = match state { + ExecuteState::Parse => { + conn.wbuf.clear(); + + let stmt = format!( + "__sqlx#{}", + conn.statement_index.fetch_add(1, Ordering::SeqCst) + ); + Parse::new(&stmt, query, &[]) + .encode(&mut conn.wbuf) + .unwrap(); + + ExecuteState::SendingParse + } + + ExecuteState::SendingParse => { + ready!(poll_write_all(&mut conn.stream.inner, &mut conn.wbuf, cx))?; + + ExecuteState::Bind + } + + ExecuteState::Bind => { + conn.wbuf.clear(); + + // FIXME: Think of a better way to build up a BIND message. Think on how to + // avoid allocation here. + + let mut values_buf = Vec::new(); + values_buf.extend_from_slice(&values.count().to_be_bytes()); + values.to_sql(&mut values_buf); + + // FIXME: We need to cache the statement name around + let stmt = format!("__sqlx#{}", conn.statement_index.load(Ordering::SeqCst) - 1); + + proto::Bind::new(&stmt, &stmt, &[], &values_buf, &[]) + .encode(&mut conn.wbuf) + .unwrap(); + + ExecuteState::SendingBind + } + + ExecuteState::SendingBind => { + ready!(poll_write_all(&mut conn.stream.inner, &mut conn.wbuf, cx))?; + + ExecuteState::Execute + } + + ExecuteState::Execute => { + conn.wbuf.clear(); + + // FIXME: We need to cache the statement name around + let stmt = format!("__sqlx#{}", conn.statement_index.load(Ordering::SeqCst) - 1); + + proto::Execute::new(&stmt, 0) + .encode(&mut conn.wbuf) + .unwrap(); + + ExecuteState::SendingExecute + } + + ExecuteState::SendingExecute => { + ready!(poll_write_all(&mut conn.stream.inner, &mut conn.wbuf, cx))?; + + ExecuteState::Sync + } + + ExecuteState::Sync => { + conn.wbuf.clear(); + proto::Sync.encode(&mut conn.wbuf).unwrap(); + + ExecuteState::SendingSync + } + + ExecuteState::SendingSync => { + ready!(poll_write_all(&mut conn.stream.inner, &mut conn.wbuf, cx))?; + + ExecuteState::Flush + } + + ExecuteState::Flush => { + ready!(Pin::new(&mut conn.stream.inner).poll_flush(cx))?; + + ExecuteState::WaitForComplete + } + + ExecuteState::WaitForComplete => { + while let Some(message) = ready!(Pin::new(&mut conn.stream).poll_next(cx)) { + match message? { + proto::Message::BindComplete | proto::Message::ParseComplete => { + // Indicates successful completion of a phase + } + + proto::Message::DataRow(_) => { + // This is EXECUTE so we are ignoring any potential results + } + + proto::Message::CommandComplete(body) => { + *out = body.rows(); + } + + proto::Message::ReadyForQuery(_) => { + // Successful completion of the whole cycle + return Poll::Ready(Ok(*out)); + } + + message => { + unimplemented!("received {:?} unimplemented message", message); + } + } + } + + // FIXME: This is technically reachable if the pg conn is dropped? + unreachable!() + } + } + } +} + +impl<'a> Future for Execute<'a> { + type Output = io::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let self_ = self.get_mut(); + poll_execute( + cx, + &mut *self_.connection, + &mut self_.state, + &self_.query, + &(), + &mut self_.rows, + ) + } +} + +// TODO: This should be cleaned up and moved to core; probably needs to be generic over back-end +// TODO: I'm using some trait recursion here.. this should probably not be exposed in core +pub trait ToSql { + /// Converts the value of `self` into the appropriate format, appending it to `out`. + fn to_sql(&self, out: &mut Vec); + + // Count the number of value parameters recursively encoded. + fn count(&self) -> i16; +} + +impl<'a> ToSql for () { + #[inline] + fn to_sql(&self, _out: &mut Vec) { + // Do nothing + } + + #[inline] + fn count(&self) -> i16 { + 0 + } +} + +impl<'a> ToSql for &'a [u8] { + #[inline] + fn to_sql(&self, out: &mut Vec) { + out.extend_from_slice(&(self.len() as i32).to_be_bytes()); + out.extend_from_slice(self); + } + + #[inline] + fn count(&self) -> i16 { + 1 + } +} + +impl<'a, T: ToSql + 'a, U: ToSql + 'a> ToSql for (T, U) { + #[inline] + fn to_sql(&self, out: &mut Vec) { + self.0.to_sql(out); + self.1.to_sql(out); + } + + #[inline] + fn count(&self) -> i16 { + self.0.count() + self.1.count() + } +} + +pub struct Bind<'a, T: ToSql + 'a> { + ex: Execute<'a>, + value: T, +} + +impl<'a, T: ToSql + 'a> Bind<'a, T> { + #[inline] + pub fn bind(self, value: &'a [u8]) -> Bind<'a, (T, &'a [u8])> { + Bind { + ex: self.ex, + value: (self.value, value), + } + } +} + +impl<'a, T: Unpin + ToSql + 'a> Future for Bind<'a, T> { + type Output = io::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let self_ = self.get_mut(); + poll_execute( + cx, + &mut *self_.ex.connection, + &mut self_.ex.state, + &self_.ex.query, + &self_.value, + &mut self_.ex.rows, + ) + } +} diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index 05c79463..dea4bc9c 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -1,19 +1,23 @@ use bytes::{BufMut, BytesMut}; use futures::{ io::{AsyncRead, AsyncWriteExt}, + ready, task::{Context, Poll}, Stream, }; use runtime::net::TcpStream; use sqlx_core::ConnectOptions; use sqlx_postgres_protocol::{Encode, Message, Terminate}; -use std::{fmt::Debug, io, pin::Pin}; +use std::{fmt::Debug, io, pin::Pin, sync::atomic::AtomicU64}; mod establish; -mod query; +mod execute; pub struct Connection { - stream: Framed, + pub(super) stream: Framed, + + // HACK: This is how we currently "name" queries when executing + statement_index: AtomicU64, // Buffer used when serializing outgoing messages // FIXME: Use BytesMut @@ -34,6 +38,7 @@ impl Connection { stream: Framed::new(stream), process_id: 0, secret_key: 0, + statement_index: AtomicU64::new(0), }; establish::establish(&mut conn, options).await?; @@ -41,8 +46,9 @@ impl Connection { Ok(conn) } - pub async fn execute<'a: 'b, 'b>(&'a mut self, query: &'b str) -> io::Result<()> { - query::query(self, query).await + #[inline] + pub fn execute<'a>(&'a mut self, query: &'a str) -> execute::Execute<'a> { + execute::execute(self, query) } pub async fn close(mut self) -> io::Result<()> { @@ -52,19 +58,14 @@ impl Connection { Ok(()) } - // Send client message to the server async fn send(&mut self, message: T) -> io::Result<()> where T: Encode + Debug, { self.wbuf.clear(); - log::trace!("send {:?}", message); - message.encode(&mut self.wbuf)?; - log::trace!("send buffer {:?}", bytes::Bytes::from(&*self.wbuf)); - self.stream.inner.write_all(&self.wbuf).await?; self.stream.inner.flush().await?; @@ -72,7 +73,7 @@ impl Connection { } } -struct Framed { +pub(super) struct Framed { inner: S, readable: bool, eof: bool, @@ -106,17 +107,7 @@ where } loop { - log::trace!("recv buffer {:?}", self_.buffer); - - let message = Message::decode(&mut self_.buffer)?; - - if log::log_enabled!(log::Level::Trace) { - if let Some(message) = &message { - log::trace!("recv {:?}", message); - } - } - - match message { + match Message::decode(&mut self_.buffer)? { Some(Message::ParameterStatus(_body)) => { // TODO: Not sure what to do with these but ignore } @@ -141,15 +132,9 @@ where let n = unsafe { let b = self_.buffer.bytes_mut(); - self_.inner.initializer().initialize(b); - let n = match Pin::new(&mut self_.inner).poll_read(cx, b)? { - Poll::Ready(cnt) => cnt, - Poll::Pending => { - return Poll::Pending; - } - }; + let n = ready!(Pin::new(&mut self_.inner).poll_read(cx, b))?; self_.buffer.advance_mut(n); diff --git a/sqlx-postgres/src/connection/query.rs b/sqlx-postgres/src/connection/query.rs deleted file mode 100644 index 455bd7fc..00000000 --- a/sqlx-postgres/src/connection/query.rs +++ /dev/null @@ -1,34 +0,0 @@ -use super::Connection; -use futures::StreamExt; -use sqlx_postgres_protocol::{Message, Query}; -use std::io; - -pub async fn query<'a: 'b, 'b>(conn: &'a mut Connection, query: &'b str) -> io::Result<()> { - conn.send(Query::new(query)).await?; - - while let Some(message) = conn.stream.next().await { - match message? { - Message::RowDescription(_) => { - // Do nothing - } - - Message::DataRow(_) => { - // Do nothing (for now) - } - - Message::ReadyForQuery(_) => { - break; - } - - Message::CommandComplete(_) => { - // Do nothing (for now) - } - - message => { - unimplemented!("received {:?} unimplemented message", message); - } - } - } - - Ok(()) -} diff --git a/src/main.rs b/src/main.rs index 9b9684b6..2d811dc5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,10 @@ use sqlx::{pg::Connection, ConnectOptions}; use std::io; +// TODO: ToSql and FromSql (to [de]serialize values from/to Rust and SQL) +// TODO: Connection strings ala postgres@localhost/sqlx_dev +// TODO: Queries (currently we only support EXECUTE [drop results]) + #[runtime::main] async fn main() -> io::Result<()> { env_logger::init(); @@ -12,12 +16,16 @@ async fn main() -> io::Result<()> { .host("127.0.0.1") .port(5432) .user("postgres") - .database("postgres") - .password("password"), + .database("sqlx__dev"), ) .await?; - conn.execute("SELECT 1, 2, 3").await?; + conn.execute("INSERT INTO \"users\" (name) VALUES ($1)") + .bind(b"Joe") + .await?; + + let count = conn.execute("SELECT name FROM users").await?; + println!("users: {}", count); conn.close().await?;