diff --git a/examples/postgres.rs b/examples/postgres.rs index 4d95d4e5..7eccecd4 100644 --- a/examples/postgres.rs +++ b/examples/postgres.rs @@ -1,86 +1,88 @@ -#![feature(async_await)] +// #![feature(async_await)] -use sqlx::{postgres::Connection, ConnectOptions}; -use std::io; +// use sqlx::{postgres::Connection, ConnectOptions}; +// use std::io; -// TODO: ToSql and FromSql (to [de]serialize values from/to Rust and SQL) -// TODO: Connection strings ala postgres@localhost/sqlx_dev +// // TODO: ToSql and FromSql (to [de]serialize values from/to Rust and SQL) +// // TODO: Connection strings ala postgres@localhost/sqlx_dev -#[runtime::main(runtime_tokio::Tokio)] -async fn main() -> io::Result<()> { - env_logger::init(); +// #[runtime::main(runtime_tokio::Tokio)] +// async fn main() -> io::Result<()> { +// env_logger::init(); - // Connect as postgres / postgres and DROP the sqlx__dev database - // if exists and then re-create it - let mut conn = Connection::establish( - ConnectOptions::new() - .host("127.0.0.1") - .port(5432) - .user("postgres") - .database("postgres"), - ) - .await?; +// // Connect as postgres / postgres and DROP the sqlx__dev database +// // if exists and then re-create it +// let mut conn = Connection::establish( +// ConnectOptions::new() +// .host("127.0.0.1") +// .port(5432) +// .user("postgres") +// .database("postgres"), +// ) +// .await?; - println!(" :: drop database (if exists) sqlx__dev"); +// println!(" :: drop database (if exists) sqlx__dev"); - conn.prepare("DROP DATABASE IF EXISTS sqlx__dev") - .execute() - .await?; +// conn.prepare("DROP DATABASE IF EXISTS sqlx__dev") +// .execute() +// .await?; - println!(" :: create database sqlx__dev"); +// println!(" :: create database sqlx__dev"); - conn.prepare("CREATE DATABASE sqlx__dev").execute().await?; +// conn.prepare("CREATE DATABASE sqlx__dev").execute().await?; - conn.close().await?; +// conn.close().await?; - let mut conn = Connection::establish( - ConnectOptions::new() - .host("127.0.0.1") - .port(5432) - .user("postgres") - .database("sqlx__dev"), - ) - .await?; +// let mut conn = Connection::establish( +// ConnectOptions::new() +// .host("127.0.0.1") +// .port(5432) +// .user("postgres") +// .database("sqlx__dev"), +// ) +// .await?; - println!(" :: create schema"); +// println!(" :: create schema"); - conn.prepare( - r#" -CREATE TABLE IF NOT EXISTS users ( - id BIGSERIAL PRIMARY KEY, - name TEXT NOT NULL -); - "#, - ) - .execute() - .await?; +// conn.prepare( +// r#" +// CREATE TABLE IF NOT EXISTS users ( +// id BIGSERIAL PRIMARY KEY, +// name TEXT NOT NULL +// ); +// "#, +// ) +// .execute() +// .await?; - println!(" :: insert"); +// println!(" :: insert"); - let new_row = conn - .prepare("INSERT INTO users (name) VALUES ($1) RETURNING id") - .bind(b"Joe") - .get() - .await?; +// let new_row = conn +// .prepare("INSERT INTO users (name) VALUES ($1) RETURNING id") +// .bind(b"Joe") +// .get() +// .await?; - let new_id = new_row.as_ref().unwrap().get(0); +// let new_id = new_row.as_ref().unwrap().get(0); - println!("insert {:?}", new_id); +// println!("insert {:?}", new_id); - // println!(" :: select"); +// // println!(" :: select"); - // conn.prepare("SELECT id FROM users") - // .select() - // .try_for_each(|row| { - // let id = row.get(0); +// // conn.prepare("SELECT id FROM users") +// // .select() +// // .try_for_each(|row| { +// // let id = row.get(0); - // println!("select {:?}", id); +// // println!("select {:?}", id); - // future::ok(()) - // }) - // .await?; +// // future::ok(()) +// // }) +// // .await?; - conn.close().await?; +// conn.close().await?; - Ok(()) -} +// Ok(()) +// } + +fn main() {} diff --git a/src/lib.rs b/src/lib.rs index 9b4cb911..b4196b64 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,6 @@ #![feature(non_exhaustive, async_await)] #![cfg_attr(test, feature(test))] - #![allow(clippy::needless_lifetimes)] - // FIXME: Remove this once API has matured #![allow(dead_code, unused_imports, unused_variables)] @@ -17,8 +15,8 @@ extern crate enum_tryfrom_derive; mod options; pub use self::options::ConnectOptions; -pub mod postgres; pub mod mariadb; +pub mod postgres; // Helper macro for writing long complex tests #[macro_use] diff --git a/src/postgres/connection/establish.rs b/src/postgres/connection/establish.rs index 77bec8de..c8509ca7 100644 --- a/src/postgres/connection/establish.rs +++ b/src/postgres/connection/establish.rs @@ -30,9 +30,9 @@ pub async fn establish<'a, 'b: 'a>( ("client_encoding", "UTF-8"), ]; - let message = StartupMessage::new(params); + let message = StartupMessage { params }; - conn.send(message); + conn.write(message); conn.flush().await?; while let Some(message) = conn.receive().await? { @@ -44,19 +44,21 @@ pub async fn establish<'a, 'b: 'a>( Message::Authentication(Authentication::CleartextPassword) => { // FIXME: Should error early (before send) if the user did not supply a password - conn.send(PasswordMessage::cleartext( + conn.write(PasswordMessage::Cleartext( options.password.unwrap_or_default(), )); + conn.flush().await?; } Message::Authentication(Authentication::Md5Password { salt }) => { // FIXME: Should error early (before send) if the user did not supply a password - conn.send(PasswordMessage::md5( - options.password.unwrap_or_default(), - options.user.unwrap_or_default(), + conn.write(PasswordMessage::Md5 { + password: options.password.unwrap_or_default(), + user: options.user.unwrap_or_default(), salt, - )); + }); + conn.flush().await?; } diff --git a/src/postgres/connection/execute.rs b/src/postgres/connection/execute.rs index 6f0eac39..7ec7faa0 100644 --- a/src/postgres/connection/execute.rs +++ b/src/postgres/connection/execute.rs @@ -4,15 +4,15 @@ use std::io; impl<'a> Prepare<'a> { pub async fn execute(self) -> io::Result { - protocol::bind::trailer( - &mut self.connection.wbuf, - self.bind_state, - self.bind_values, - &[], - ); + // protocol::bind::trailer( + // &mut self.connection.wbuf, + // self.bind_state, + // self.bind_values, + // &[], + // ); - protocol::execute(&mut self.connection.wbuf, "", 0); - protocol::sync(&mut self.connection.wbuf); + // protocol::execute(&mut self.connection.wbuf, "", 0); + // protocol::sync(&mut self.connection.wbuf); self.connection.flush().await?; diff --git a/src/postgres/connection/get.rs b/src/postgres/connection/get.rs index f5238804..2f4eec6b 100644 --- a/src/postgres/connection/get.rs +++ b/src/postgres/connection/get.rs @@ -4,16 +4,16 @@ use std::io; impl<'a> Prepare<'a> { pub async fn get(self) -> io::Result> { - protocol::bind::trailer( - &mut self.connection.wbuf, - self.bind_state, - self.bind_values, - &[], - ); + // protocol::bind::trailer( + // &mut self.connection.wbuf, + // self.bind_state, + // self.bind_values, + // &[], + // ); - protocol::execute(&mut self.connection.wbuf, "", 1); - protocol::close::portal(&mut self.connection.wbuf, ""); - protocol::sync(&mut self.connection.wbuf); + // protocol::execute(&mut self.connection.wbuf, "", 1); + // protocol::close::portal(&mut self.connection.wbuf, ""); + // protocol::sync(&mut self.connection.wbuf); self.connection.flush().await?; @@ -21,7 +21,10 @@ impl<'a> Prepare<'a> { while let Some(message) = self.connection.receive().await? { match message { - Message::BindComplete | Message::ParseComplete | Message::PortalSuspended | Message::CloseComplete => { + Message::BindComplete + | Message::ParseComplete + | Message::PortalSuspended + | Message::CloseComplete => { // Indicates successful completion of a phase } diff --git a/src/postgres/connection/mod.rs b/src/postgres/connection/mod.rs index 7dbb04e5..c3f87a29 100644 --- a/src/postgres/connection/mod.rs +++ b/src/postgres/connection/mod.rs @@ -62,7 +62,7 @@ impl Connection { } pub async fn close(mut self) -> io::Result<()> { - self.send(Terminate); + self.write(Terminate); self.flush().await?; self.stream.close().await?; @@ -124,64 +124,14 @@ impl Connection { } } - fn send(&mut self, message: T) - where - T: Encode + Debug, - { - log::trace!("encode {:?}", message); - - // TODO: Encoding should not be fallible - message.encode(&mut self.wbuf).unwrap(); + fn write(&mut self, message: impl Encode) { + message.encode(&mut self.wbuf); } async fn flush(&mut self) -> io::Result<()> { - // TODO: Find some other way to print a Vec as an ASCII escaped string - log::trace!("send {:?}", bytes::Bytes::from(&*self.wbuf)); - - WriteAllVec::new(&mut self.stream, &mut self.wbuf).await?; - - self.stream.flush().await?; + self.stream.write_all(&self.wbuf).await?; + self.wbuf.clear(); Ok(()) } } - -// Derived from: https://rust-lang-nursery.github.io/futures-api-docs/0.3.0-alpha.16/src/futures_util/io/write_all.rs.html#10-13 -// With alterations to be more efficient if we're writing from a mutable vector -// that we can erase - -// TODO: Move to Core under 'sqlx_core::io' perhaps? -// TODO: Perhaps the futures project wants this? - -pub struct WriteAllVec<'a, W: ?Sized + Unpin> { - writer: &'a mut W, - buf: &'a mut Vec, -} - -impl Unpin for WriteAllVec<'_, W> {} - -impl<'a, W: AsyncWrite + ?Sized + Unpin> WriteAllVec<'a, W> { - pub(super) fn new(writer: &'a mut W, buf: &'a mut Vec) -> Self { - WriteAllVec { writer, buf } - } -} - -impl Future for WriteAllVec<'_, W> { - type Output = io::Result<()>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = &mut *self; - - while !this.buf.is_empty() { - let n = ready!(Pin::new(&mut this.writer).poll_write(cx, this.buf))?; - - this.buf.truncate(this.buf.len() - n); - - if n == 0 { - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); - } - } - - Poll::Ready(Ok(())) - } -} diff --git a/src/postgres/connection/prepare.rs b/src/postgres/connection/prepare.rs index 58e32111..1ca223fe 100644 --- a/src/postgres/connection/prepare.rs +++ b/src/postgres/connection/prepare.rs @@ -1,39 +1,29 @@ use super::Connection; -use crate::postgres::protocol::{self, Parse}; +use crate::{ + postgres::protocol::{self, Parse}, + types::ToSql, +}; pub struct Prepare<'a> { pub(super) connection: &'a mut Connection, - pub(super) bind_state: (usize, usize), - pub(super) bind_values: usize, } #[inline] pub fn prepare<'a, 'b>(connection: &'a mut Connection, query: &'b str) -> Prepare<'a> { // TODO: Use a hash map to cache the parse // TODO: Use named statements - connection.send(Parse::new("", query, &[])); + connection.write(Parse { + portal: "", + query, + param_types: &[], + }); - let bind_state = protocol::bind::header(&mut connection.wbuf, "", "", &[]); - - Prepare { - connection, - bind_state, - bind_values: 0, - } + Prepare { connection } } -impl<'a> Prepare<'a> { - #[inline] - pub fn bind<'b>(mut self, value: &'b [u8]) -> Self { - protocol::bind::value(&mut self.connection.wbuf, value); - self.bind_values += 1; - self - } - - #[inline] - pub fn bind_null<'b>(mut self) -> Self { - protocol::bind::value_null(&mut self.connection.wbuf); - self.bind_values += 1; - self - } -} +// impl<'a> Prepare<'a> { +// #[inline] +// pub fn bind(mut self, value: impl ToSql) -> Self { +// unimplemented!() +// } +// } diff --git a/src/postgres/connection/select.rs b/src/postgres/connection/select.rs index 9e41d2a9..94f4814f 100644 --- a/src/postgres/connection/select.rs +++ b/src/postgres/connection/select.rs @@ -5,15 +5,15 @@ use std::io; impl<'a> Prepare<'a> { pub fn select(self) -> impl Stream> + 'a + Unpin { - protocol::bind::trailer( - &mut self.connection.wbuf, - self.bind_state, - self.bind_values, - &[], - ); + // protocol::bind::trailer( + // &mut self.connection.wbuf, + // self.bind_state, + // self.bind_values, + // &[], + // ); - protocol::execute(&mut self.connection.wbuf, "", 0); - protocol::sync(&mut self.connection.wbuf); + // protocol::execute(&mut self.connection.wbuf, "", 0); + // protocol::sync(&mut self.connection.wbuf); // FIXME: Manually implement Stream on a new type to avoid the unfold adapter stream::unfold(self.connection, |conn| { diff --git a/src/postgres/protocol/bind.rs b/src/postgres/protocol/bind.rs index 24c3c81e..6094d269 100644 --- a/src/postgres/protocol/bind.rs +++ b/src/postgres/protocol/bind.rs @@ -1,68 +1,46 @@ +use super::{BufMut, Encode}; use byteorder::{BigEndian, ByteOrder}; -pub fn header(buf: &mut Vec, portal: &str, statement: &str, formats: &[u16]) -> (usize, usize) { - buf.push(b'B'); +pub struct Bind<'a> { + /// The name of the destination portal (an empty string selects the unnamed portal). + portal: &'a str, - // reserve room for the length - let len_pos = buf.len(); - buf.extend_from_slice(&[0, 0, 0, 0]); + /// The name of the source prepared statement (an empty string selects the unnamed prepared statement). + statement: &'a str, - buf.extend_from_slice(portal.as_bytes()); - buf.push(b'\0'); + /// The parameter format codes. Each must presently be zero (text) or one (binary). + /// + /// There can be zero to indicate that there are no parameters or that the parameters all use the + /// default format (text); or one, in which case the specified format code is applied to all + /// parameters; or it can equal the actual number of parameters. + formats: &'a [i16], - buf.extend_from_slice(statement.as_bytes()); - buf.push(b'\0'); + values: &'a [u8], - buf.extend_from_slice(&(formats.len() as i16).to_be_bytes()); - - for format in formats { - buf.extend_from_slice(&format.to_be_bytes()); - } - - // reserve room for the values count - let value_len_pos = buf.len(); - buf.extend_from_slice(&[0, 0]); - - (len_pos, value_len_pos) + /// The result-column format codes. Each must presently be zero (text) or one (binary). + /// + /// There can be zero to indicate that there are no result columns or that the + /// result columns should all use the default format (text); or one, in which + /// case the specified format code is applied to all result columns (if any); + /// or it can equal the actual number of result columns of the query. + result_formats: &'a [i16], } -pub fn value(buf: &mut Vec, value: &[u8]) { - buf.extend_from_slice(&(value.len() as u32).to_be_bytes()); - buf.extend_from_slice(value); -} +impl Encode for Bind<'_> { + fn encode(&self, buf: &mut Vec) { + buf.put_byte(b'B'); -pub fn value_null(buf: &mut Vec) { - buf.extend_from_slice(&(-1_i32).to_be_bytes()); -} + let pos = buf.len(); + buf.put_int_32(0); // skip over len -pub fn trailer(buf: &mut Vec, state: (usize, usize), values: usize, result_formats: &[i16]) { - buf.extend_from_slice(&(result_formats.len() as i16).to_be_bytes()); + buf.put_str(self.portal); + buf.put_str(self.statement); + buf.put_array_int_16(&self.formats); + buf.put(self.values); + buf.put_array_int_16(&self.result_formats); - for format in result_formats { - buf.extend_from_slice(&format.to_be_bytes()); - } - - // Calculate and emplace the total len of the message - let len = buf.len() - state.0; - BigEndian::write_u32(&mut buf[(state.0)..], len as u32); - - // Emplace the total num of values - BigEndian::write_u16(&mut buf[(state.1)..], values as u16); -} - -#[cfg(test)] -mod tests { - const BIND: &[u8] = b"B\0\0\0\x16\0\0\0\0\0\x02\0\0\0\x011\0\0\0\x012\0\0"; - - #[test] - fn it_encodes_bind_for_two() { - let mut buf = Vec::new(); - - let state = super::header(&mut buf, "", "", &[]); - super::value(&mut buf, b"1"); - super::value(&mut buf, b"2"); - super::trailer(&mut buf, state, 2, &[]); - - assert_eq!(buf, BIND); + // Write-back the len to the beginning of this frame + let len = buf.len() - pos; + BigEndian::write_i32(&mut buf[pos..], len as i32); } } diff --git a/src/postgres/protocol/cancel_request.rs b/src/postgres/protocol/cancel_request.rs new file mode 100644 index 00000000..9f8624ee --- /dev/null +++ b/src/postgres/protocol/cancel_request.rs @@ -0,0 +1,22 @@ +use super::{BufMut, Encode}; + +/// Sent instead of [`StartupMessage`] with a new connection to cancel a running query on an existing +/// connection. +/// +/// https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.9 +pub struct CancelRequest { + /// The process ID of the target backend. + pub process_id: i32, + + /// The secret key for the target backend. + pub secret_key: i32, +} + +impl Encode for CancelRequest { + fn encode(&self, buf: &mut Vec) { + buf.put_int_32(16); // message length + buf.put_int_32(8087_7102); // constant for cancel request + buf.put_int_32(self.process_id); + buf.put_int_32(self.secret_key); + } +} diff --git a/src/postgres/protocol/close.rs b/src/postgres/protocol/close.rs index e298f5c8..280cb23f 100644 --- a/src/postgres/protocol/close.rs +++ b/src/postgres/protocol/close.rs @@ -1,43 +1,61 @@ +use super::{BufMut, Encode}; -pub fn portal(buf: &mut Vec, name: &str) { - buf.push(b'C'); +// TODO: Separate into two structs, ClosePortal and CloseStatement (?) - let len = 4 + name.len() + 2; - buf.extend_from_slice(&(len as i32).to_be_bytes()); - - buf.push(b'P'); - - buf.extend_from_slice(name.as_bytes()); - buf.push(b'\0'); +#[repr(u8)] +pub enum CloseKind { + PreparedStatement, + Portal, } -pub fn statement(buf: &mut Vec, name: &str) { - buf.push(b'C'); +pub struct Close<'a> { + kind: CloseKind, - let len = 4 + name.len() + 2; - buf.extend_from_slice(&(len as i32).to_be_bytes()); + /// The name of the prepared statement or portal to close (an empty string selects the + /// unnamed prepared statement or portal). + name: &'a str, +} - buf.push(b'S'); - - buf.extend_from_slice(name.as_bytes()); - buf.push(b'\0'); +impl Encode for Close<'_> { + fn encode(&self, buf: &mut Vec) { + buf.put_byte(b'C'); + // len + kind + nul + len(string) + buf.put_int_32((4 + 1 + 1 + self.name.len()) as i32); + buf.put_byte(match self.kind { + CloseKind::PreparedStatement => b'S', + CloseKind::Portal => b'P', + }); + buf.put_str(self.name); + } } #[cfg(test)] mod test { + use super::{BufMut, Close, CloseKind, Encode}; + #[test] fn it_encodes_close_portal() { - let mut buf = vec![]; - super::portal(&mut buf, "ABC123"); + let mut buf = Vec::new(); + let m = Close { + kind: CloseKind::Portal, + name: "__sqlx_p_1", + }; - assert_eq!(&buf, b"C\x00\x00\x00\x0fPABC123\x00"); + m.encode(&mut buf); + + assert_eq!(buf, b"C\0\0\0\x10P__sqlx_p_1\0"); } #[test] fn it_encodes_close_statement() { - let mut buf = vec![]; - super::statement(&mut buf, "95 apples"); + let mut buf = Vec::new(); + let m = Close { + kind: CloseKind::PreparedStatement, + name: "__sqlx_s_1", + }; - assert_eq!(&buf, b"C\x00\x00\x00\x12S95 apples\x00"); + m.encode(&mut buf); + + assert_eq!(buf, b"C\0\0\0\x10S__sqlx_s_1\0"); } } diff --git a/src/postgres/protocol/copy_data.rs b/src/postgres/protocol/copy_data.rs new file mode 100644 index 00000000..9f0bea4d --- /dev/null +++ b/src/postgres/protocol/copy_data.rs @@ -0,0 +1,27 @@ +use super::{BufMut, Encode}; + +// TODO: Implement Decode and think on an optimal representation + +/* +# Optimal for Encode +pub struct CopyData<'a> { data: &'a [u8] } + +# Optimal for Decode +pub struct CopyData { data: Bytes } + +# 1) Two structs (names?) +# 2) "Either" inner abstraction; removes ease of construction for Encode +*/ + +pub struct CopyData<'a> { + pub data: &'a [u8], +} + +impl Encode for CopyData<'_> { + fn encode(&self, buf: &mut Vec) { + buf.put_byte(b'd'); + // len + nul + len(string) + buf.put_int_32((4 + 1 + self.data.len()) as i32); + buf.put(&self.data); + } +} diff --git a/src/postgres/protocol/copy_done.rs b/src/postgres/protocol/copy_done.rs new file mode 100644 index 00000000..9b89c82d --- /dev/null +++ b/src/postgres/protocol/copy_done.rs @@ -0,0 +1,13 @@ +use super::{BufMut, Encode}; + +// TODO: Implement Decode + +pub struct CopyDone; + +impl Encode for CopyDone { + #[inline] + fn encode(&self, buf: &mut Vec) { + buf.put_byte(b'c'); + buf.put_int_32(4); + } +} diff --git a/src/postgres/protocol/copy_fail.rs b/src/postgres/protocol/copy_fail.rs new file mode 100644 index 00000000..87b2b995 --- /dev/null +++ b/src/postgres/protocol/copy_fail.rs @@ -0,0 +1,14 @@ +use super::{BufMut, Encode}; + +pub struct CopyFail<'a> { + pub error: &'a str, +} + +impl Encode for CopyFail<'_> { + fn encode(&self, buf: &mut Vec) { + buf.put_byte(b'f'); + // len + nul + len(string) + buf.put_int_32((4 + 1 + self.error.len()) as i32); + buf.put_str(&self.error); + } +} diff --git a/src/postgres/protocol/describe.rs b/src/postgres/protocol/describe.rs index 246f6a8d..2638db2e 100644 --- a/src/postgres/protocol/describe.rs +++ b/src/postgres/protocol/describe.rs @@ -1,54 +1,61 @@ -/// The Describe message (portal variant) specifies the name of an existing portal -/// (or an empty string for the unnamed portal). The response is a RowDescription message -/// describing the rows that will be returned by executing the portal; or a NoData message -/// if the portal does not contain a query that will return rows; or ErrorResponse if there is no such portal. -pub fn portal(buf: &mut Vec, name: &str) { - buf.push(b'D'); +use super::{BufMut, Encode}; - let len = 4 + name.len() + 2; - buf.extend_from_slice(&(len as i32).to_be_bytes()); +// TODO: Separate into two structs, DescribePortal and DescribeStatement (?) - buf.push(b'P'); - - buf.extend_from_slice(name.as_bytes()); - buf.push(b'\0'); +#[repr(u8)] +pub enum DescribeKind { + PreparedStatement, + Portal, } -/// The Describe message (statement variant) specifies the name of an existing prepared statement -/// (or an empty string for the unnamed prepared statement). The response is a ParameterDescription -/// message describing the parameters needed by the statement, followed by a RowDescription message -/// describing the rows that will be returned when the statement is eventually executed -/// (or a NoData message if the statement will not return rows). ErrorResponse is issued if -/// there is no such prepared statement. Note that since Bind has not yet been issued, -/// the formats to be used for returned columns are not yet known to the backend; the -/// format code fields in the RowDescription message will be zeroes in this case. -pub fn statement(buf: &mut Vec, name: &str) { - buf.push(b'D'); +pub struct Describe<'a> { + kind: DescribeKind, - let len = 4 + name.len() + 2; - buf.extend_from_slice(&(len as i32).to_be_bytes()); + /// The name of the prepared statement or portal to describe (an empty string selects the + /// unnamed prepared statement or portal). + name: &'a str, +} - buf.push(b'S'); - - buf.extend_from_slice(name.as_bytes()); - buf.push(b'\0'); +impl Encode for Describe<'_> { + fn encode(&self, buf: &mut Vec) { + buf.put_byte(b'D'); + // len + kind + nul + len(string) + buf.put_int_32((4 + 1 + 1 + self.name.len()) as i32); + buf.put_byte(match self.kind { + DescribeKind::PreparedStatement => b'S', + DescribeKind::Portal => b'P', + }); + buf.put_str(self.name); + } } #[cfg(test)] mod test { + use super::{BufMut, Describe, DescribeKind, Encode}; + #[test] fn it_encodes_describe_portal() { - let mut buf = vec![]; - super::portal(&mut buf, "ABC123"); + let mut buf = Vec::new(); + let m = Describe { + kind: DescribeKind::Portal, + name: "__sqlx_p_1", + }; - assert_eq!(&buf, b"D\x00\x00\x00\x0fPABC123\x00"); + m.encode(&mut buf); + + assert_eq!(buf, b"D\0\0\0\x10P__sqlx_p_1\0"); } #[test] fn it_encodes_describe_statement() { - let mut buf = vec![]; - super::statement(&mut buf, "95 apples"); + let mut buf = Vec::new(); + let m = Describe { + kind: DescribeKind::PreparedStatement, + name: "__sqlx_s_1", + }; - assert_eq!(&buf, b"D\x00\x00\x00\x12S95 apples\x00"); + m.encode(&mut buf); + + assert_eq!(buf, b"D\0\0\0\x10S__sqlx_s_1\0"); } } diff --git a/src/postgres/protocol/encode.rs b/src/postgres/protocol/encode.rs index 8ad8251f..de3a5b74 100644 --- a/src/postgres/protocol/encode.rs +++ b/src/postgres/protocol/encode.rs @@ -1,11 +1,69 @@ use std::io; pub trait Encode { - // TODO: Remove - fn size_hint(&self) -> usize { - 0 + fn encode(&self, buf: &mut Vec); +} + +pub trait BufMut { + fn put(&mut self, bytes: &[u8]); + + fn put_byte(&mut self, value: u8); + + fn put_int_16(&mut self, value: i16); + + fn put_int_32(&mut self, value: i32); + + fn put_array_int_16(&mut self, values: &[i16]); + + fn put_array_int_32(&mut self, values: &[i32]); + + fn put_str(&mut self, value: &str); +} + +impl BufMut for Vec { + #[inline] + fn put(&mut self, bytes: &[u8]) { + self.extend_from_slice(bytes); } - // FIXME: Use BytesMut and not Vec (also remove the error type here) - fn encode(&self, buf: &mut Vec) -> io::Result<()>; + #[inline] + fn put_byte(&mut self, value: u8) { + self.push(value); + } + + #[inline] + fn put_int_16(&mut self, value: i16) { + self.extend_from_slice(&value.to_be_bytes()); + } + + #[inline] + fn put_int_32(&mut self, value: i32) { + self.extend_from_slice(&value.to_be_bytes()); + } + + #[inline] + fn put_str(&mut self, value: &str) { + self.extend_from_slice(value.as_bytes()); + self.push(0); + } + + #[inline] + fn put_array_int_16(&mut self, values: &[i16]) { + // FIXME: What happens here when len(values) > i16 + self.put_int_16(values.len() as i16); + + for value in values { + self.put_int_16(*value); + } + } + + #[inline] + fn put_array_int_32(&mut self, values: &[i32]) { + // FIXME: What happens here when len(values) > i16 + self.put_int_16(values.len() as i16); + + for value in values { + self.put_int_32(*value); + } + } } diff --git a/src/postgres/protocol/execute.rs b/src/postgres/protocol/execute.rs index cf03f98c..94004489 100644 --- a/src/postgres/protocol/execute.rs +++ b/src/postgres/protocol/execute.rs @@ -1,28 +1,20 @@ -/// Specifies the portal name (empty string denotes the unnamed portal) and a maximum -/// result-row count (zero meaning “fetch all rows”). The result-row count is only meaningful -/// for portals containing commands that return row sets; in other cases the command is -/// always executed to completion, and the row count is ignored. -pub fn execute(buf: &mut Vec, portal: &str, limit: i32) { - buf.push(b'E'); +use super::{BufMut, Encode}; - let len = 4 + portal.len() + 1 + 4; - buf.extend_from_slice(&(len as i32).to_be_bytes()); +pub struct Execute<'a> { + /// The name of the portal to execute (an empty string selects the unnamed portal). + pub portal: &'a str, - // portal - buf.extend_from_slice(portal.as_bytes()); - buf.push(b'\0'); - - // limit - buf.extend_from_slice(&limit.to_be_bytes()); + /// Maximum number of rows to return, if portal contains a query + /// that returns rows (ignored otherwise). Zero denotes “no limit”. + pub limit: i32, } -#[cfg(test)] -mod tests { - #[test] - fn it_encodes_execute() { - let mut buf = Vec::new(); - super::execute(&mut buf, "", 0); - - assert_eq!(&*buf, b"E\0\0\0\t\0\0\0\0\0"); +impl Encode for Execute<'_> { + fn encode(&self, buf: &mut Vec) { + buf.put_byte(b'E'); + // len + nul + len(string) + limit + buf.put_int_32((4 + 1 + self.portal.len() + 4) as i32); + buf.put_str(&self.portal); + buf.put_int_32(self.limit); } } diff --git a/src/postgres/protocol/flush.rs b/src/postgres/protocol/flush.rs new file mode 100644 index 00000000..88e8f594 --- /dev/null +++ b/src/postgres/protocol/flush.rs @@ -0,0 +1,11 @@ +use super::{BufMut, Encode}; + +pub struct Flush; + +impl Encode for Flush { + #[inline] + fn encode(&self, buf: &mut Vec) { + buf.put_byte(b'H'); + buf.put_int_32(4); + } +} diff --git a/src/postgres/protocol/mod.rs b/src/postgres/protocol/mod.rs index ba2c3ebf..f5f4dfa0 100644 --- a/src/postgres/protocol/mod.rs +++ b/src/postgres/protocol/mod.rs @@ -1,49 +1,64 @@ +mod bind; +mod cancel_request; +mod close; +mod copy_data; +mod copy_done; +mod copy_fail; +mod describe; +mod encode; +mod execute; +mod flush; +mod parse; +mod password_message; +mod query; +mod startup_message; +mod sync; +mod terminate; + +// TODO: mod gss_enc_request; +// TODO: mod gss_response; +// TODO: mod sasl_initial_response; +// TODO: mod sasl_response; +// TODO: mod ssl_request; + +pub use self::{ + bind::Bind, + cancel_request::CancelRequest, + close::Close, + copy_data::CopyData, + copy_done::CopyDone, + copy_fail::CopyFail, + describe::Describe, + encode::{BufMut, Encode}, + execute::Execute, + flush::Flush, + parse::Parse, + password_message::PasswordMessage, + query::Query, + startup_message::StartupMessage, + sync::Sync, + terminate::Terminate, +}; + +// TODO: Audit backend protocol + +mod authentication; mod backend_key_data; mod command_complete; mod data_row; mod decode; -mod encode; mod message; mod notification_response; mod parameter_description; mod parameter_status; -mod parse; -mod password_message; -mod query; mod ready_for_query; mod response; mod row_description; -mod startup_message; -mod terminate; - -pub mod bind; -pub mod describe; -pub mod close; - -mod execute; -mod sync; - -pub use self::{execute::execute, sync::sync}; - -mod authentication; pub use self::{ - authentication::Authentication, - backend_key_data::BackendKeyData, - command_complete::CommandComplete, - data_row::DataRow, - decode::Decode, - encode::Encode, - message::Message, - notification_response::NotificationResponse, - parameter_description::ParameterDescription, - parameter_status::ParameterStatus, - parse::Parse, - password_message::PasswordMessage, - query::Query, - ready_for_query::{ReadyForQuery, TransactionStatus}, - response::{Response, Severity}, - row_description::{FieldDescription, FieldDescriptions, RowDescription}, - startup_message::StartupMessage, - terminate::Terminate, + authentication::Authentication, backend_key_data::BackendKeyData, + command_complete::CommandComplete, data_row::DataRow, decode::Decode, message::Message, + notification_response::NotificationResponse, parameter_description::ParameterDescription, + parameter_status::ParameterStatus, ready_for_query::ReadyForQuery, response::Response, + row_description::RowDescription, }; diff --git a/src/postgres/protocol/parse.rs b/src/postgres/protocol/parse.rs index 360d46f4..42553382 100644 --- a/src/postgres/protocol/parse.rs +++ b/src/postgres/protocol/parse.rs @@ -1,43 +1,22 @@ -use super::Encode; -use std::io; +use super::{BufMut, Encode}; -#[derive(Debug)] pub struct Parse<'a> { - portal: &'a str, - query: &'a str, - param_types: &'a [i32], + pub portal: &'a str, + pub query: &'a str, + pub param_types: &'a [i32], } -impl<'a> Parse<'a> { - pub fn new(portal: &'a str, query: &'a str, param_types: &'a [i32]) -> Self { - Self { - portal, - query, - param_types, - } - } -} - -impl<'a> Encode for Parse<'a> { - fn encode(&self, buf: &mut Vec) -> io::Result<()> { - buf.push(b'P'); +impl Encode for Parse<'_> { + fn encode(&self, buf: &mut Vec) { + buf.put_byte(b'P'); + // len + portal + nul + query + null + len(param_types) + param_types let len = 4 + self.portal.len() + 1 + self.query.len() + 1 + 2 + self.param_types.len() * 4; + buf.put_int_32(len as i32); - buf.extend_from_slice(&(len as i32).to_be_bytes()); + buf.put_str(self.portal); + buf.put_str(self.query); - buf.extend_from_slice(self.portal.as_bytes()); - buf.push(b'\0'); - - buf.extend_from_slice(self.query.as_bytes()); - buf.push(b'\0'); - - buf.extend_from_slice(&(self.param_types.len() as i16).to_be_bytes()); - - for param_type in self.param_types { - buf.extend_from_slice(¶m_type.to_be_bytes()); - } - - Ok(()) + buf.put_array_int_32(&self.param_types); } } diff --git a/src/postgres/protocol/password_message.rs b/src/postgres/protocol/password_message.rs index 69797257..8da7e782 100644 --- a/src/postgres/protocol/password_message.rs +++ b/src/postgres/protocol/password_message.rs @@ -1,60 +1,50 @@ -use super::Encode; -use bytes::Bytes; +use super::{BufMut, Encode}; use md5::{Digest, Md5}; -use std::io; #[derive(Debug)] -pub struct PasswordMessage { - password: Bytes, +pub enum PasswordMessage<'a> { + Cleartext(&'a str), + Md5 { + password: &'a str, + user: &'a str, + salt: [u8; 4], + }, } -impl PasswordMessage { - /// Create a `PasswordMessage` with an unecrypted password. - pub fn cleartext(password: &str) -> Self { - Self { - password: Bytes::from(password), +impl Encode for PasswordMessage<'_> { + fn encode(&self, buf: &mut Vec) { + buf.put_byte(b'p'); + + match self { + PasswordMessage::Cleartext(s) => { + // len + password + nul + buf.put_int_32((4 + s.len() + 1) as i32); + buf.put_str(s); + } + + PasswordMessage::Md5 { + password, + user, + salt, + } => { + let mut hasher = Md5::new(); + + hasher.input(password); + hasher.input(user); + + let credentials = hex::encode(hasher.result_reset()); + + hasher.input(credentials); + hasher.input(salt); + + let salted = hex::encode(hasher.result()); + + // len + "md5" + (salted) + buf.put_int_32((4 + 3 + salted.len()) as i32); + + buf.put(b"md5"); + buf.put(salted.as_bytes()); + } } } - - /// Create a `PasswordMessage` by hasing the password, user, and salt together using MD5. - pub fn md5(password: &str, user: &str, salt: [u8; 4]) -> Self { - let mut hasher = Md5::new(); - - hasher.input(password); - hasher.input(user); - - let credentials = hex::encode(hasher.result_reset()); - - hasher.input(credentials); - hasher.input(salt); - - let salted = hex::encode(hasher.result()); - - let mut password = Vec::with_capacity(3 + salted.len()); - password.extend_from_slice(b"md5"); - password.extend_from_slice(salted.as_bytes()); - - Self { - password: Bytes::from(password), - } - } - - /// The password (encrypted, if requested). - pub fn password(&self) -> &[u8] { - &self.password - } -} - -impl Encode for PasswordMessage { - fn size_hint(&self) -> usize { - self.password.len() + 5 - } - - fn encode(&self, buf: &mut Vec) -> io::Result<()> { - buf.push(b'p'); - buf.extend_from_slice(&(self.password.len() + 4).to_be_bytes()); - buf.extend_from_slice(&self.password); - - Ok(()) - } } diff --git a/src/postgres/protocol/query.rs b/src/postgres/protocol/query.rs index fc8218cb..c48a6905 100644 --- a/src/postgres/protocol/query.rs +++ b/src/postgres/protocol/query.rs @@ -1,44 +1,31 @@ -use super::Encode; -use std::io; +use super::{BufMut, Encode}; -#[derive(Debug)] -pub struct Query<'a>(&'a str); - -impl<'a> Query<'a> { - #[inline] - pub fn new(query: &'a str) -> Self { - Self(query) - } -} +pub struct Query<'a>(pub &'a str); impl Encode for Query<'_> { - fn encode(&self, buf: &mut Vec) -> io::Result<()> { - let len = self.0.len() + 4 + 1; - buf.push(b'Q'); - buf.extend_from_slice(&(len as u32).to_be_bytes()); - buf.extend_from_slice(self.0.as_bytes()); - buf.push(0); + fn encode(&self, buf: &mut Vec) { + buf.put_byte(b'Q'); - Ok(()) + // len + query + nul + buf.put_int_32((4 + self.0.len() + 1) as i32); + + buf.put_str(self.0); } } #[cfg(test)] mod tests { - use super::{Encode, Query}; - use std::io; + use super::{BufMut, Encode, Query}; const QUERY_SELECT_1: &[u8] = b"Q\0\0\0\rSELECT 1\0"; #[test] - fn it_encodes_query() -> io::Result<()> { - let message = Query::new("SELECT 1"); - + fn it_encodes_query() { let mut buf = Vec::new(); - message.encode(&mut buf)?; + let m = Query("SELECT 1"); - assert_eq!(&*buf, QUERY_SELECT_1); + m.encode(&mut buf); - Ok(()) + assert_eq!(buf, QUERY_SELECT_1); } } diff --git a/src/postgres/protocol/startup_message.rs b/src/postgres/protocol/startup_message.rs index 93e26b63..2cb06657 100644 --- a/src/postgres/protocol/startup_message.rs +++ b/src/postgres/protocol/startup_message.rs @@ -1,64 +1,46 @@ -use super::Encode; +use super::{BufMut, Encode}; use byteorder::{BigEndian, ByteOrder}; -use std::io; -#[derive(Debug)] pub struct StartupMessage<'a> { - params: &'a [(&'a str, &'a str)], + pub params: &'a [(&'a str, &'a str)], } -impl<'a> StartupMessage<'a> { - #[inline] - pub fn new(params: &'a [(&'a str, &'a str)]) -> Self { - Self { params } - } - - #[inline] - pub fn params(&self) -> &'a [(&'a str, &'a str)] { - self.params - } -} - -impl<'a> Encode for StartupMessage<'a> { - fn encode(&self, buf: &mut Vec) -> io::Result<()> { +impl Encode for StartupMessage<'_> { + fn encode(&self, buf: &mut Vec) { let pos = buf.len(); - buf.extend_from_slice(&(0 as u32).to_be_bytes()); // skip over len - buf.extend_from_slice(&3_u16.to_be_bytes()); // major version - buf.extend_from_slice(&0_u16.to_be_bytes()); // minor version + buf.put_int_32(0); // skip over len + + // protocol version number (3.0) + buf.put_int_32(196608); for (name, value) in self.params { - buf.extend_from_slice(name.as_bytes()); - buf.push(0); - buf.extend_from_slice(value.as_bytes()); - buf.push(0); + buf.put_str(name); + buf.put_str(value); } - buf.push(0); + buf.put_byte(0); // Write-back the len to the beginning of this frame let len = buf.len() - pos; - BigEndian::write_u32(&mut buf[pos..], len as u32); - - Ok(()) + BigEndian::write_i32(&mut buf[pos..], len as i32); } } #[cfg(test)] mod tests { - use super::{Encode, StartupMessage}; - use std::io; + use super::{BufMut, Encode, StartupMessage}; const STARTUP_MESSAGE: &[u8] = b"\0\0\0)\0\x03\0\0user\0postgres\0database\0postgres\0\0"; #[test] - fn it_encodes_startup_message() -> io::Result<()> { - let message = StartupMessage::new(&[("user", "postgres"), ("database", "postgres")]); - + fn it_encodes_startup_message() { let mut buf = Vec::new(); - message.encode(&mut buf)?; + let m = StartupMessage { + params: &[("user", "postgres"), ("database", "postgres")], + }; - assert_eq!(&*buf, STARTUP_MESSAGE); + m.encode(&mut buf); - Ok(()) + assert_eq!(buf, STARTUP_MESSAGE); } } diff --git a/src/postgres/protocol/sync.rs b/src/postgres/protocol/sync.rs index 03a53318..51566fc3 100644 --- a/src/postgres/protocol/sync.rs +++ b/src/postgres/protocol/sync.rs @@ -1,30 +1,11 @@ -/// This parameterless message causes the backend to close the current transaction if it's not inside -/// a BEGIN/COMMIT transaction block (“close” meaning to commit if no error, or roll back if error). -/// Then a ReadyForQuery response is issued. -pub fn sync(buf: &mut Vec) { - buf.push(b'S'); - buf.extend_from_slice(&4_i32.to_be_bytes()); -} +use super::{BufMut, Encode}; -#[cfg(test)] -mod tests { - #[test] - fn it_encodes_sync() { - let mut buf = Vec::new(); - super::sync(&mut buf); +pub struct Sync; - assert_eq!(&*buf, b"S\0\0\0\x04"); - } - - #[bench] - fn bench_encode_sync(b: &mut test::Bencher) { - let mut buf = Vec::new(); - - b.iter(|| { - for _ in 0..1000 { - buf.clear(); - super::sync(&mut buf); - } - }); +impl Encode for Sync { + #[inline] + fn encode(&self, buf: &mut Vec) { + buf.put_byte(b'S'); + buf.put_int_32(4); } } diff --git a/src/postgres/protocol/terminate.rs b/src/postgres/protocol/terminate.rs index c6c87b14..a0402815 100644 --- a/src/postgres/protocol/terminate.rs +++ b/src/postgres/protocol/terminate.rs @@ -1,34 +1,11 @@ -use super::Encode; -use std::io; +use super::{BufMut, Encode}; -#[derive(Debug)] pub struct Terminate; impl Encode for Terminate { - fn encode(&self, buf: &mut Vec) -> io::Result<()> { - buf.push(b'X'); - buf.extend_from_slice(&4_u32.to_be_bytes()); - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::{Encode, Terminate}; - use std::io; - - const TERMINATE: &[u8] = b"X\0\0\0\x04"; - - #[test] - fn it_encodes_terminate() -> io::Result<()> { - let message = Terminate; - - let mut buf = Vec::new(); - message.encode(&mut buf)?; - - assert_eq!(&*buf, TERMINATE); - - Ok(()) + #[inline] + fn encode(&self, buf: &mut Vec) { + buf.put_byte(b'X'); + buf.put_int_32(4); } }