[postgres] Optimize code quality of encoding

This commit is contained in:
Ryan Leckey 2019-08-02 20:50:17 -07:00
parent 3496819e5b
commit ff3cc6a2eb
25 changed files with 553 additions and 557 deletions

View File

@ -1,86 +1,88 @@
#![feature(async_await)]
// #![feature(async_await)]
use sqlx::{postgres::Connection, ConnectOptions};
use std::io;
// use sqlx::{postgres::Connection, ConnectOptions};
// use std::io;
// TODO: ToSql and FromSql (to [de]serialize values from/to Rust and SQL)
// TODO: Connection strings ala postgres@localhost/sqlx_dev
// // TODO: ToSql and FromSql (to [de]serialize values from/to Rust and SQL)
// // TODO: Connection strings ala postgres@localhost/sqlx_dev
#[runtime::main(runtime_tokio::Tokio)]
async fn main() -> io::Result<()> {
env_logger::init();
// #[runtime::main(runtime_tokio::Tokio)]
// async fn main() -> io::Result<()> {
// env_logger::init();
// Connect as postgres / postgres and DROP the sqlx__dev database
// if exists and then re-create it
let mut conn = Connection::establish(
ConnectOptions::new()
.host("127.0.0.1")
.port(5432)
.user("postgres")
.database("postgres"),
)
.await?;
// // Connect as postgres / postgres and DROP the sqlx__dev database
// // if exists and then re-create it
// let mut conn = Connection::establish(
// ConnectOptions::new()
// .host("127.0.0.1")
// .port(5432)
// .user("postgres")
// .database("postgres"),
// )
// .await?;
println!(" :: drop database (if exists) sqlx__dev");
// println!(" :: drop database (if exists) sqlx__dev");
conn.prepare("DROP DATABASE IF EXISTS sqlx__dev")
.execute()
.await?;
// conn.prepare("DROP DATABASE IF EXISTS sqlx__dev")
// .execute()
// .await?;
println!(" :: create database sqlx__dev");
// println!(" :: create database sqlx__dev");
conn.prepare("CREATE DATABASE sqlx__dev").execute().await?;
// conn.prepare("CREATE DATABASE sqlx__dev").execute().await?;
conn.close().await?;
// conn.close().await?;
let mut conn = Connection::establish(
ConnectOptions::new()
.host("127.0.0.1")
.port(5432)
.user("postgres")
.database("sqlx__dev"),
)
.await?;
// let mut conn = Connection::establish(
// ConnectOptions::new()
// .host("127.0.0.1")
// .port(5432)
// .user("postgres")
// .database("sqlx__dev"),
// )
// .await?;
println!(" :: create schema");
// println!(" :: create schema");
conn.prepare(
r#"
CREATE TABLE IF NOT EXISTS users (
id BIGSERIAL PRIMARY KEY,
name TEXT NOT NULL
);
"#,
)
.execute()
.await?;
// conn.prepare(
// r#"
// CREATE TABLE IF NOT EXISTS users (
// id BIGSERIAL PRIMARY KEY,
// name TEXT NOT NULL
// );
// "#,
// )
// .execute()
// .await?;
println!(" :: insert");
// println!(" :: insert");
let new_row = conn
.prepare("INSERT INTO users (name) VALUES ($1) RETURNING id")
.bind(b"Joe")
.get()
.await?;
// let new_row = conn
// .prepare("INSERT INTO users (name) VALUES ($1) RETURNING id")
// .bind(b"Joe")
// .get()
// .await?;
let new_id = new_row.as_ref().unwrap().get(0);
// let new_id = new_row.as_ref().unwrap().get(0);
println!("insert {:?}", new_id);
// println!("insert {:?}", new_id);
// println!(" :: select");
// // println!(" :: select");
// conn.prepare("SELECT id FROM users")
// .select()
// .try_for_each(|row| {
// let id = row.get(0);
// // conn.prepare("SELECT id FROM users")
// // .select()
// // .try_for_each(|row| {
// // let id = row.get(0);
// println!("select {:?}", id);
// // println!("select {:?}", id);
// future::ok(())
// })
// .await?;
// // future::ok(())
// // })
// // .await?;
conn.close().await?;
// conn.close().await?;
Ok(())
}
// Ok(())
// }
fn main() {}

View File

@ -1,8 +1,6 @@
#![feature(non_exhaustive, async_await)]
#![cfg_attr(test, feature(test))]
#![allow(clippy::needless_lifetimes)]
// FIXME: Remove this once API has matured
#![allow(dead_code, unused_imports, unused_variables)]
@ -17,8 +15,8 @@ extern crate enum_tryfrom_derive;
mod options;
pub use self::options::ConnectOptions;
pub mod postgres;
pub mod mariadb;
pub mod postgres;
// Helper macro for writing long complex tests
#[macro_use]

View File

@ -30,9 +30,9 @@ pub async fn establish<'a, 'b: 'a>(
("client_encoding", "UTF-8"),
];
let message = StartupMessage::new(params);
let message = StartupMessage { params };
conn.send(message);
conn.write(message);
conn.flush().await?;
while let Some(message) = conn.receive().await? {
@ -44,19 +44,21 @@ pub async fn establish<'a, 'b: 'a>(
Message::Authentication(Authentication::CleartextPassword) => {
// FIXME: Should error early (before send) if the user did not supply a password
conn.send(PasswordMessage::cleartext(
conn.write(PasswordMessage::Cleartext(
options.password.unwrap_or_default(),
));
conn.flush().await?;
}
Message::Authentication(Authentication::Md5Password { salt }) => {
// FIXME: Should error early (before send) if the user did not supply a password
conn.send(PasswordMessage::md5(
options.password.unwrap_or_default(),
options.user.unwrap_or_default(),
conn.write(PasswordMessage::Md5 {
password: options.password.unwrap_or_default(),
user: options.user.unwrap_or_default(),
salt,
));
});
conn.flush().await?;
}

View File

@ -4,15 +4,15 @@ use std::io;
impl<'a> Prepare<'a> {
pub async fn execute(self) -> io::Result<u64> {
protocol::bind::trailer(
&mut self.connection.wbuf,
self.bind_state,
self.bind_values,
&[],
);
// protocol::bind::trailer(
// &mut self.connection.wbuf,
// self.bind_state,
// self.bind_values,
// &[],
// );
protocol::execute(&mut self.connection.wbuf, "", 0);
protocol::sync(&mut self.connection.wbuf);
// protocol::execute(&mut self.connection.wbuf, "", 0);
// protocol::sync(&mut self.connection.wbuf);
self.connection.flush().await?;

View File

@ -4,16 +4,16 @@ use std::io;
impl<'a> Prepare<'a> {
pub async fn get(self) -> io::Result<Option<DataRow>> {
protocol::bind::trailer(
&mut self.connection.wbuf,
self.bind_state,
self.bind_values,
&[],
);
// protocol::bind::trailer(
// &mut self.connection.wbuf,
// self.bind_state,
// self.bind_values,
// &[],
// );
protocol::execute(&mut self.connection.wbuf, "", 1);
protocol::close::portal(&mut self.connection.wbuf, "");
protocol::sync(&mut self.connection.wbuf);
// protocol::execute(&mut self.connection.wbuf, "", 1);
// protocol::close::portal(&mut self.connection.wbuf, "");
// protocol::sync(&mut self.connection.wbuf);
self.connection.flush().await?;
@ -21,7 +21,10 @@ impl<'a> Prepare<'a> {
while let Some(message) = self.connection.receive().await? {
match message {
Message::BindComplete | Message::ParseComplete | Message::PortalSuspended | Message::CloseComplete => {
Message::BindComplete
| Message::ParseComplete
| Message::PortalSuspended
| Message::CloseComplete => {
// Indicates successful completion of a phase
}

View File

@ -62,7 +62,7 @@ impl Connection {
}
pub async fn close(mut self) -> io::Result<()> {
self.send(Terminate);
self.write(Terminate);
self.flush().await?;
self.stream.close().await?;
@ -124,64 +124,14 @@ impl Connection {
}
}
fn send<T>(&mut self, message: T)
where
T: Encode + Debug,
{
log::trace!("encode {:?}", message);
// TODO: Encoding should not be fallible
message.encode(&mut self.wbuf).unwrap();
fn write(&mut self, message: impl Encode) {
message.encode(&mut self.wbuf);
}
async fn flush(&mut self) -> io::Result<()> {
// TODO: Find some other way to print a Vec<u8> as an ASCII escaped string
log::trace!("send {:?}", bytes::Bytes::from(&*self.wbuf));
WriteAllVec::new(&mut self.stream, &mut self.wbuf).await?;
self.stream.flush().await?;
self.stream.write_all(&self.wbuf).await?;
self.wbuf.clear();
Ok(())
}
}
// Derived from: https://rust-lang-nursery.github.io/futures-api-docs/0.3.0-alpha.16/src/futures_util/io/write_all.rs.html#10-13
// With alterations to be more efficient if we're writing from a mutable vector
// that we can erase
// TODO: Move to Core under 'sqlx_core::io' perhaps?
// TODO: Perhaps the futures project wants this?
pub struct WriteAllVec<'a, W: ?Sized + Unpin> {
writer: &'a mut W,
buf: &'a mut Vec<u8>,
}
impl<W: ?Sized + Unpin> Unpin for WriteAllVec<'_, W> {}
impl<'a, W: AsyncWrite + ?Sized + Unpin> WriteAllVec<'a, W> {
pub(super) fn new(writer: &'a mut W, buf: &'a mut Vec<u8>) -> Self {
WriteAllVec { writer, buf }
}
}
impl<W: AsyncWrite + ?Sized + Unpin> Future for WriteAllVec<'_, W> {
type Output = io::Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = &mut *self;
while !this.buf.is_empty() {
let n = ready!(Pin::new(&mut this.writer).poll_write(cx, this.buf))?;
this.buf.truncate(this.buf.len() - n);
if n == 0 {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
}
}
Poll::Ready(Ok(()))
}
}

View File

@ -1,39 +1,29 @@
use super::Connection;
use crate::postgres::protocol::{self, Parse};
use crate::{
postgres::protocol::{self, Parse},
types::ToSql,
};
pub struct Prepare<'a> {
pub(super) connection: &'a mut Connection,
pub(super) bind_state: (usize, usize),
pub(super) bind_values: usize,
}
#[inline]
pub fn prepare<'a, 'b>(connection: &'a mut Connection, query: &'b str) -> Prepare<'a> {
// TODO: Use a hash map to cache the parse
// TODO: Use named statements
connection.send(Parse::new("", query, &[]));
connection.write(Parse {
portal: "",
query,
param_types: &[],
});
let bind_state = protocol::bind::header(&mut connection.wbuf, "", "", &[]);
Prepare {
connection,
bind_state,
bind_values: 0,
}
Prepare { connection }
}
impl<'a> Prepare<'a> {
#[inline]
pub fn bind<'b>(mut self, value: &'b [u8]) -> Self {
protocol::bind::value(&mut self.connection.wbuf, value);
self.bind_values += 1;
self
}
#[inline]
pub fn bind_null<'b>(mut self) -> Self {
protocol::bind::value_null(&mut self.connection.wbuf);
self.bind_values += 1;
self
}
}
// impl<'a> Prepare<'a> {
// #[inline]
// pub fn bind<T>(mut self, value: impl ToSql<T>) -> Self {
// unimplemented!()
// }
// }

View File

@ -5,15 +5,15 @@ use std::io;
impl<'a> Prepare<'a> {
pub fn select(self) -> impl Stream<Item = Result<DataRow, io::Error>> + 'a + Unpin {
protocol::bind::trailer(
&mut self.connection.wbuf,
self.bind_state,
self.bind_values,
&[],
);
// protocol::bind::trailer(
// &mut self.connection.wbuf,
// self.bind_state,
// self.bind_values,
// &[],
// );
protocol::execute(&mut self.connection.wbuf, "", 0);
protocol::sync(&mut self.connection.wbuf);
// protocol::execute(&mut self.connection.wbuf, "", 0);
// protocol::sync(&mut self.connection.wbuf);
// FIXME: Manually implement Stream on a new type to avoid the unfold adapter
stream::unfold(self.connection, |conn| {

View File

@ -1,68 +1,46 @@
use super::{BufMut, Encode};
use byteorder::{BigEndian, ByteOrder};
pub fn header(buf: &mut Vec<u8>, portal: &str, statement: &str, formats: &[u16]) -> (usize, usize) {
buf.push(b'B');
pub struct Bind<'a> {
/// The name of the destination portal (an empty string selects the unnamed portal).
portal: &'a str,
// reserve room for the length
let len_pos = buf.len();
buf.extend_from_slice(&[0, 0, 0, 0]);
/// The name of the source prepared statement (an empty string selects the unnamed prepared statement).
statement: &'a str,
buf.extend_from_slice(portal.as_bytes());
buf.push(b'\0');
/// The parameter format codes. Each must presently be zero (text) or one (binary).
///
/// There can be zero to indicate that there are no parameters or that the parameters all use the
/// default format (text); or one, in which case the specified format code is applied to all
/// parameters; or it can equal the actual number of parameters.
formats: &'a [i16],
buf.extend_from_slice(statement.as_bytes());
buf.push(b'\0');
values: &'a [u8],
buf.extend_from_slice(&(formats.len() as i16).to_be_bytes());
for format in formats {
buf.extend_from_slice(&format.to_be_bytes());
}
// reserve room for the values count
let value_len_pos = buf.len();
buf.extend_from_slice(&[0, 0]);
(len_pos, value_len_pos)
/// The result-column format codes. Each must presently be zero (text) or one (binary).
///
/// There can be zero to indicate that there are no result columns or that the
/// result columns should all use the default format (text); or one, in which
/// case the specified format code is applied to all result columns (if any);
/// or it can equal the actual number of result columns of the query.
result_formats: &'a [i16],
}
pub fn value(buf: &mut Vec<u8>, value: &[u8]) {
buf.extend_from_slice(&(value.len() as u32).to_be_bytes());
buf.extend_from_slice(value);
}
impl Encode for Bind<'_> {
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'B');
pub fn value_null(buf: &mut Vec<u8>) {
buf.extend_from_slice(&(-1_i32).to_be_bytes());
}
let pos = buf.len();
buf.put_int_32(0); // skip over len
pub fn trailer(buf: &mut Vec<u8>, state: (usize, usize), values: usize, result_formats: &[i16]) {
buf.extend_from_slice(&(result_formats.len() as i16).to_be_bytes());
buf.put_str(self.portal);
buf.put_str(self.statement);
buf.put_array_int_16(&self.formats);
buf.put(self.values);
buf.put_array_int_16(&self.result_formats);
for format in result_formats {
buf.extend_from_slice(&format.to_be_bytes());
}
// Calculate and emplace the total len of the message
let len = buf.len() - state.0;
BigEndian::write_u32(&mut buf[(state.0)..], len as u32);
// Emplace the total num of values
BigEndian::write_u16(&mut buf[(state.1)..], values as u16);
}
#[cfg(test)]
mod tests {
const BIND: &[u8] = b"B\0\0\0\x16\0\0\0\0\0\x02\0\0\0\x011\0\0\0\x012\0\0";
#[test]
fn it_encodes_bind_for_two() {
let mut buf = Vec::new();
let state = super::header(&mut buf, "", "", &[]);
super::value(&mut buf, b"1");
super::value(&mut buf, b"2");
super::trailer(&mut buf, state, 2, &[]);
assert_eq!(buf, BIND);
// Write-back the len to the beginning of this frame
let len = buf.len() - pos;
BigEndian::write_i32(&mut buf[pos..], len as i32);
}
}

View File

@ -0,0 +1,22 @@
use super::{BufMut, Encode};
/// Sent instead of [`StartupMessage`] with a new connection to cancel a running query on an existing
/// connection.
///
/// https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.9
pub struct CancelRequest {
/// The process ID of the target backend.
pub process_id: i32,
/// The secret key for the target backend.
pub secret_key: i32,
}
impl Encode for CancelRequest {
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_int_32(16); // message length
buf.put_int_32(8087_7102); // constant for cancel request
buf.put_int_32(self.process_id);
buf.put_int_32(self.secret_key);
}
}

View File

@ -1,43 +1,61 @@
use super::{BufMut, Encode};
pub fn portal(buf: &mut Vec<u8>, name: &str) {
buf.push(b'C');
// TODO: Separate into two structs, ClosePortal and CloseStatement (?)
let len = 4 + name.len() + 2;
buf.extend_from_slice(&(len as i32).to_be_bytes());
buf.push(b'P');
buf.extend_from_slice(name.as_bytes());
buf.push(b'\0');
#[repr(u8)]
pub enum CloseKind {
PreparedStatement,
Portal,
}
pub fn statement(buf: &mut Vec<u8>, name: &str) {
buf.push(b'C');
pub struct Close<'a> {
kind: CloseKind,
let len = 4 + name.len() + 2;
buf.extend_from_slice(&(len as i32).to_be_bytes());
/// The name of the prepared statement or portal to close (an empty string selects the
/// unnamed prepared statement or portal).
name: &'a str,
}
buf.push(b'S');
buf.extend_from_slice(name.as_bytes());
buf.push(b'\0');
impl Encode for Close<'_> {
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'C');
// len + kind + nul + len(string)
buf.put_int_32((4 + 1 + 1 + self.name.len()) as i32);
buf.put_byte(match self.kind {
CloseKind::PreparedStatement => b'S',
CloseKind::Portal => b'P',
});
buf.put_str(self.name);
}
}
#[cfg(test)]
mod test {
use super::{BufMut, Close, CloseKind, Encode};
#[test]
fn it_encodes_close_portal() {
let mut buf = vec![];
super::portal(&mut buf, "ABC123");
let mut buf = Vec::new();
let m = Close {
kind: CloseKind::Portal,
name: "__sqlx_p_1",
};
assert_eq!(&buf, b"C\x00\x00\x00\x0fPABC123\x00");
m.encode(&mut buf);
assert_eq!(buf, b"C\0\0\0\x10P__sqlx_p_1\0");
}
#[test]
fn it_encodes_close_statement() {
let mut buf = vec![];
super::statement(&mut buf, "95 apples");
let mut buf = Vec::new();
let m = Close {
kind: CloseKind::PreparedStatement,
name: "__sqlx_s_1",
};
assert_eq!(&buf, b"C\x00\x00\x00\x12S95 apples\x00");
m.encode(&mut buf);
assert_eq!(buf, b"C\0\0\0\x10S__sqlx_s_1\0");
}
}

View File

@ -0,0 +1,27 @@
use super::{BufMut, Encode};
// TODO: Implement Decode and think on an optimal representation
/*
# Optimal for Encode
pub struct CopyData<'a> { data: &'a [u8] }
# Optimal for Decode
pub struct CopyData { data: Bytes }
# 1) Two structs (names?)
# 2) "Either" inner abstraction; removes ease of construction for Encode
*/
pub struct CopyData<'a> {
pub data: &'a [u8],
}
impl Encode for CopyData<'_> {
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'd');
// len + nul + len(string)
buf.put_int_32((4 + 1 + self.data.len()) as i32);
buf.put(&self.data);
}
}

View File

@ -0,0 +1,13 @@
use super::{BufMut, Encode};
// TODO: Implement Decode
pub struct CopyDone;
impl Encode for CopyDone {
#[inline]
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'c');
buf.put_int_32(4);
}
}

View File

@ -0,0 +1,14 @@
use super::{BufMut, Encode};
pub struct CopyFail<'a> {
pub error: &'a str,
}
impl Encode for CopyFail<'_> {
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'f');
// len + nul + len(string)
buf.put_int_32((4 + 1 + self.error.len()) as i32);
buf.put_str(&self.error);
}
}

View File

@ -1,54 +1,61 @@
/// The Describe message (portal variant) specifies the name of an existing portal
/// (or an empty string for the unnamed portal). The response is a RowDescription message
/// describing the rows that will be returned by executing the portal; or a NoData message
/// if the portal does not contain a query that will return rows; or ErrorResponse if there is no such portal.
pub fn portal(buf: &mut Vec<u8>, name: &str) {
buf.push(b'D');
use super::{BufMut, Encode};
let len = 4 + name.len() + 2;
buf.extend_from_slice(&(len as i32).to_be_bytes());
// TODO: Separate into two structs, DescribePortal and DescribeStatement (?)
buf.push(b'P');
buf.extend_from_slice(name.as_bytes());
buf.push(b'\0');
#[repr(u8)]
pub enum DescribeKind {
PreparedStatement,
Portal,
}
/// The Describe message (statement variant) specifies the name of an existing prepared statement
/// (or an empty string for the unnamed prepared statement). The response is a ParameterDescription
/// message describing the parameters needed by the statement, followed by a RowDescription message
/// describing the rows that will be returned when the statement is eventually executed
/// (or a NoData message if the statement will not return rows). ErrorResponse is issued if
/// there is no such prepared statement. Note that since Bind has not yet been issued,
/// the formats to be used for returned columns are not yet known to the backend; the
/// format code fields in the RowDescription message will be zeroes in this case.
pub fn statement(buf: &mut Vec<u8>, name: &str) {
buf.push(b'D');
pub struct Describe<'a> {
kind: DescribeKind,
let len = 4 + name.len() + 2;
buf.extend_from_slice(&(len as i32).to_be_bytes());
/// The name of the prepared statement or portal to describe (an empty string selects the
/// unnamed prepared statement or portal).
name: &'a str,
}
buf.push(b'S');
buf.extend_from_slice(name.as_bytes());
buf.push(b'\0');
impl Encode for Describe<'_> {
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'D');
// len + kind + nul + len(string)
buf.put_int_32((4 + 1 + 1 + self.name.len()) as i32);
buf.put_byte(match self.kind {
DescribeKind::PreparedStatement => b'S',
DescribeKind::Portal => b'P',
});
buf.put_str(self.name);
}
}
#[cfg(test)]
mod test {
use super::{BufMut, Describe, DescribeKind, Encode};
#[test]
fn it_encodes_describe_portal() {
let mut buf = vec![];
super::portal(&mut buf, "ABC123");
let mut buf = Vec::new();
let m = Describe {
kind: DescribeKind::Portal,
name: "__sqlx_p_1",
};
assert_eq!(&buf, b"D\x00\x00\x00\x0fPABC123\x00");
m.encode(&mut buf);
assert_eq!(buf, b"D\0\0\0\x10P__sqlx_p_1\0");
}
#[test]
fn it_encodes_describe_statement() {
let mut buf = vec![];
super::statement(&mut buf, "95 apples");
let mut buf = Vec::new();
let m = Describe {
kind: DescribeKind::PreparedStatement,
name: "__sqlx_s_1",
};
assert_eq!(&buf, b"D\x00\x00\x00\x12S95 apples\x00");
m.encode(&mut buf);
assert_eq!(buf, b"D\0\0\0\x10S__sqlx_s_1\0");
}
}

View File

@ -1,11 +1,69 @@
use std::io;
pub trait Encode {
// TODO: Remove
fn size_hint(&self) -> usize {
0
fn encode(&self, buf: &mut Vec<u8>);
}
pub trait BufMut {
fn put(&mut self, bytes: &[u8]);
fn put_byte(&mut self, value: u8);
fn put_int_16(&mut self, value: i16);
fn put_int_32(&mut self, value: i32);
fn put_array_int_16(&mut self, values: &[i16]);
fn put_array_int_32(&mut self, values: &[i32]);
fn put_str(&mut self, value: &str);
}
impl BufMut for Vec<u8> {
#[inline]
fn put(&mut self, bytes: &[u8]) {
self.extend_from_slice(bytes);
}
// FIXME: Use BytesMut and not Vec<u8> (also remove the error type here)
fn encode(&self, buf: &mut Vec<u8>) -> io::Result<()>;
#[inline]
fn put_byte(&mut self, value: u8) {
self.push(value);
}
#[inline]
fn put_int_16(&mut self, value: i16) {
self.extend_from_slice(&value.to_be_bytes());
}
#[inline]
fn put_int_32(&mut self, value: i32) {
self.extend_from_slice(&value.to_be_bytes());
}
#[inline]
fn put_str(&mut self, value: &str) {
self.extend_from_slice(value.as_bytes());
self.push(0);
}
#[inline]
fn put_array_int_16(&mut self, values: &[i16]) {
// FIXME: What happens here when len(values) > i16
self.put_int_16(values.len() as i16);
for value in values {
self.put_int_16(*value);
}
}
#[inline]
fn put_array_int_32(&mut self, values: &[i32]) {
// FIXME: What happens here when len(values) > i16
self.put_int_16(values.len() as i16);
for value in values {
self.put_int_32(*value);
}
}
}

View File

@ -1,28 +1,20 @@
/// Specifies the portal name (empty string denotes the unnamed portal) and a maximum
/// result-row count (zero meaning “fetch all rows”). The result-row count is only meaningful
/// for portals containing commands that return row sets; in other cases the command is
/// always executed to completion, and the row count is ignored.
pub fn execute(buf: &mut Vec<u8>, portal: &str, limit: i32) {
buf.push(b'E');
use super::{BufMut, Encode};
let len = 4 + portal.len() + 1 + 4;
buf.extend_from_slice(&(len as i32).to_be_bytes());
pub struct Execute<'a> {
/// The name of the portal to execute (an empty string selects the unnamed portal).
pub portal: &'a str,
// portal
buf.extend_from_slice(portal.as_bytes());
buf.push(b'\0');
// limit
buf.extend_from_slice(&limit.to_be_bytes());
/// Maximum number of rows to return, if portal contains a query
/// that returns rows (ignored otherwise). Zero denotes “no limit”.
pub limit: i32,
}
#[cfg(test)]
mod tests {
#[test]
fn it_encodes_execute() {
let mut buf = Vec::new();
super::execute(&mut buf, "", 0);
assert_eq!(&*buf, b"E\0\0\0\t\0\0\0\0\0");
impl Encode for Execute<'_> {
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'E');
// len + nul + len(string) + limit
buf.put_int_32((4 + 1 + self.portal.len() + 4) as i32);
buf.put_str(&self.portal);
buf.put_int_32(self.limit);
}
}

View File

@ -0,0 +1,11 @@
use super::{BufMut, Encode};
pub struct Flush;
impl Encode for Flush {
#[inline]
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'H');
buf.put_int_32(4);
}
}

View File

@ -1,49 +1,64 @@
mod bind;
mod cancel_request;
mod close;
mod copy_data;
mod copy_done;
mod copy_fail;
mod describe;
mod encode;
mod execute;
mod flush;
mod parse;
mod password_message;
mod query;
mod startup_message;
mod sync;
mod terminate;
// TODO: mod gss_enc_request;
// TODO: mod gss_response;
// TODO: mod sasl_initial_response;
// TODO: mod sasl_response;
// TODO: mod ssl_request;
pub use self::{
bind::Bind,
cancel_request::CancelRequest,
close::Close,
copy_data::CopyData,
copy_done::CopyDone,
copy_fail::CopyFail,
describe::Describe,
encode::{BufMut, Encode},
execute::Execute,
flush::Flush,
parse::Parse,
password_message::PasswordMessage,
query::Query,
startup_message::StartupMessage,
sync::Sync,
terminate::Terminate,
};
// TODO: Audit backend protocol
mod authentication;
mod backend_key_data;
mod command_complete;
mod data_row;
mod decode;
mod encode;
mod message;
mod notification_response;
mod parameter_description;
mod parameter_status;
mod parse;
mod password_message;
mod query;
mod ready_for_query;
mod response;
mod row_description;
mod startup_message;
mod terminate;
pub mod bind;
pub mod describe;
pub mod close;
mod execute;
mod sync;
pub use self::{execute::execute, sync::sync};
mod authentication;
pub use self::{
authentication::Authentication,
backend_key_data::BackendKeyData,
command_complete::CommandComplete,
data_row::DataRow,
decode::Decode,
encode::Encode,
message::Message,
notification_response::NotificationResponse,
parameter_description::ParameterDescription,
parameter_status::ParameterStatus,
parse::Parse,
password_message::PasswordMessage,
query::Query,
ready_for_query::{ReadyForQuery, TransactionStatus},
response::{Response, Severity},
row_description::{FieldDescription, FieldDescriptions, RowDescription},
startup_message::StartupMessage,
terminate::Terminate,
authentication::Authentication, backend_key_data::BackendKeyData,
command_complete::CommandComplete, data_row::DataRow, decode::Decode, message::Message,
notification_response::NotificationResponse, parameter_description::ParameterDescription,
parameter_status::ParameterStatus, ready_for_query::ReadyForQuery, response::Response,
row_description::RowDescription,
};

View File

@ -1,43 +1,22 @@
use super::Encode;
use std::io;
use super::{BufMut, Encode};
#[derive(Debug)]
pub struct Parse<'a> {
portal: &'a str,
query: &'a str,
param_types: &'a [i32],
pub portal: &'a str,
pub query: &'a str,
pub param_types: &'a [i32],
}
impl<'a> Parse<'a> {
pub fn new(portal: &'a str, query: &'a str, param_types: &'a [i32]) -> Self {
Self {
portal,
query,
param_types,
}
}
}
impl<'a> Encode for Parse<'a> {
fn encode(&self, buf: &mut Vec<u8>) -> io::Result<()> {
buf.push(b'P');
impl Encode for Parse<'_> {
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'P');
// len + portal + nul + query + null + len(param_types) + param_types
let len = 4 + self.portal.len() + 1 + self.query.len() + 1 + 2 + self.param_types.len() * 4;
buf.put_int_32(len as i32);
buf.extend_from_slice(&(len as i32).to_be_bytes());
buf.put_str(self.portal);
buf.put_str(self.query);
buf.extend_from_slice(self.portal.as_bytes());
buf.push(b'\0');
buf.extend_from_slice(self.query.as_bytes());
buf.push(b'\0');
buf.extend_from_slice(&(self.param_types.len() as i16).to_be_bytes());
for param_type in self.param_types {
buf.extend_from_slice(&param_type.to_be_bytes());
}
Ok(())
buf.put_array_int_32(&self.param_types);
}
}

View File

@ -1,60 +1,50 @@
use super::Encode;
use bytes::Bytes;
use super::{BufMut, Encode};
use md5::{Digest, Md5};
use std::io;
#[derive(Debug)]
pub struct PasswordMessage {
password: Bytes,
pub enum PasswordMessage<'a> {
Cleartext(&'a str),
Md5 {
password: &'a str,
user: &'a str,
salt: [u8; 4],
},
}
impl PasswordMessage {
/// Create a `PasswordMessage` with an unecrypted password.
pub fn cleartext(password: &str) -> Self {
Self {
password: Bytes::from(password),
impl Encode for PasswordMessage<'_> {
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'p');
match self {
PasswordMessage::Cleartext(s) => {
// len + password + nul
buf.put_int_32((4 + s.len() + 1) as i32);
buf.put_str(s);
}
PasswordMessage::Md5 {
password,
user,
salt,
} => {
let mut hasher = Md5::new();
hasher.input(password);
hasher.input(user);
let credentials = hex::encode(hasher.result_reset());
hasher.input(credentials);
hasher.input(salt);
let salted = hex::encode(hasher.result());
// len + "md5" + (salted)
buf.put_int_32((4 + 3 + salted.len()) as i32);
buf.put(b"md5");
buf.put(salted.as_bytes());
}
}
}
/// Create a `PasswordMessage` by hasing the password, user, and salt together using MD5.
pub fn md5(password: &str, user: &str, salt: [u8; 4]) -> Self {
let mut hasher = Md5::new();
hasher.input(password);
hasher.input(user);
let credentials = hex::encode(hasher.result_reset());
hasher.input(credentials);
hasher.input(salt);
let salted = hex::encode(hasher.result());
let mut password = Vec::with_capacity(3 + salted.len());
password.extend_from_slice(b"md5");
password.extend_from_slice(salted.as_bytes());
Self {
password: Bytes::from(password),
}
}
/// The password (encrypted, if requested).
pub fn password(&self) -> &[u8] {
&self.password
}
}
impl Encode for PasswordMessage {
fn size_hint(&self) -> usize {
self.password.len() + 5
}
fn encode(&self, buf: &mut Vec<u8>) -> io::Result<()> {
buf.push(b'p');
buf.extend_from_slice(&(self.password.len() + 4).to_be_bytes());
buf.extend_from_slice(&self.password);
Ok(())
}
}

View File

@ -1,44 +1,31 @@
use super::Encode;
use std::io;
use super::{BufMut, Encode};
#[derive(Debug)]
pub struct Query<'a>(&'a str);
impl<'a> Query<'a> {
#[inline]
pub fn new(query: &'a str) -> Self {
Self(query)
}
}
pub struct Query<'a>(pub &'a str);
impl Encode for Query<'_> {
fn encode(&self, buf: &mut Vec<u8>) -> io::Result<()> {
let len = self.0.len() + 4 + 1;
buf.push(b'Q');
buf.extend_from_slice(&(len as u32).to_be_bytes());
buf.extend_from_slice(self.0.as_bytes());
buf.push(0);
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'Q');
Ok(())
// len + query + nul
buf.put_int_32((4 + self.0.len() + 1) as i32);
buf.put_str(self.0);
}
}
#[cfg(test)]
mod tests {
use super::{Encode, Query};
use std::io;
use super::{BufMut, Encode, Query};
const QUERY_SELECT_1: &[u8] = b"Q\0\0\0\rSELECT 1\0";
#[test]
fn it_encodes_query() -> io::Result<()> {
let message = Query::new("SELECT 1");
fn it_encodes_query() {
let mut buf = Vec::new();
message.encode(&mut buf)?;
let m = Query("SELECT 1");
assert_eq!(&*buf, QUERY_SELECT_1);
m.encode(&mut buf);
Ok(())
assert_eq!(buf, QUERY_SELECT_1);
}
}

View File

@ -1,64 +1,46 @@
use super::Encode;
use super::{BufMut, Encode};
use byteorder::{BigEndian, ByteOrder};
use std::io;
#[derive(Debug)]
pub struct StartupMessage<'a> {
params: &'a [(&'a str, &'a str)],
pub params: &'a [(&'a str, &'a str)],
}
impl<'a> StartupMessage<'a> {
#[inline]
pub fn new(params: &'a [(&'a str, &'a str)]) -> Self {
Self { params }
}
#[inline]
pub fn params(&self) -> &'a [(&'a str, &'a str)] {
self.params
}
}
impl<'a> Encode for StartupMessage<'a> {
fn encode(&self, buf: &mut Vec<u8>) -> io::Result<()> {
impl Encode for StartupMessage<'_> {
fn encode(&self, buf: &mut Vec<u8>) {
let pos = buf.len();
buf.extend_from_slice(&(0 as u32).to_be_bytes()); // skip over len
buf.extend_from_slice(&3_u16.to_be_bytes()); // major version
buf.extend_from_slice(&0_u16.to_be_bytes()); // minor version
buf.put_int_32(0); // skip over len
// protocol version number (3.0)
buf.put_int_32(196608);
for (name, value) in self.params {
buf.extend_from_slice(name.as_bytes());
buf.push(0);
buf.extend_from_slice(value.as_bytes());
buf.push(0);
buf.put_str(name);
buf.put_str(value);
}
buf.push(0);
buf.put_byte(0);
// Write-back the len to the beginning of this frame
let len = buf.len() - pos;
BigEndian::write_u32(&mut buf[pos..], len as u32);
Ok(())
BigEndian::write_i32(&mut buf[pos..], len as i32);
}
}
#[cfg(test)]
mod tests {
use super::{Encode, StartupMessage};
use std::io;
use super::{BufMut, Encode, StartupMessage};
const STARTUP_MESSAGE: &[u8] = b"\0\0\0)\0\x03\0\0user\0postgres\0database\0postgres\0\0";
#[test]
fn it_encodes_startup_message() -> io::Result<()> {
let message = StartupMessage::new(&[("user", "postgres"), ("database", "postgres")]);
fn it_encodes_startup_message() {
let mut buf = Vec::new();
message.encode(&mut buf)?;
let m = StartupMessage {
params: &[("user", "postgres"), ("database", "postgres")],
};
assert_eq!(&*buf, STARTUP_MESSAGE);
m.encode(&mut buf);
Ok(())
assert_eq!(buf, STARTUP_MESSAGE);
}
}

View File

@ -1,30 +1,11 @@
/// This parameterless message causes the backend to close the current transaction if it's not inside
/// a BEGIN/COMMIT transaction block (“close” meaning to commit if no error, or roll back if error).
/// Then a ReadyForQuery response is issued.
pub fn sync(buf: &mut Vec<u8>) {
buf.push(b'S');
buf.extend_from_slice(&4_i32.to_be_bytes());
}
use super::{BufMut, Encode};
#[cfg(test)]
mod tests {
#[test]
fn it_encodes_sync() {
let mut buf = Vec::new();
super::sync(&mut buf);
pub struct Sync;
assert_eq!(&*buf, b"S\0\0\0\x04");
}
#[bench]
fn bench_encode_sync(b: &mut test::Bencher) {
let mut buf = Vec::new();
b.iter(|| {
for _ in 0..1000 {
buf.clear();
super::sync(&mut buf);
}
});
impl Encode for Sync {
#[inline]
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'S');
buf.put_int_32(4);
}
}

View File

@ -1,34 +1,11 @@
use super::Encode;
use std::io;
use super::{BufMut, Encode};
#[derive(Debug)]
pub struct Terminate;
impl Encode for Terminate {
fn encode(&self, buf: &mut Vec<u8>) -> io::Result<()> {
buf.push(b'X');
buf.extend_from_slice(&4_u32.to_be_bytes());
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::{Encode, Terminate};
use std::io;
const TERMINATE: &[u8] = b"X\0\0\0\x04";
#[test]
fn it_encodes_terminate() -> io::Result<()> {
let message = Terminate;
let mut buf = Vec::new();
message.encode(&mut buf)?;
assert_eq!(&*buf, TERMINATE);
Ok(())
#[inline]
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'X');
buf.put_int_32(4);
}
}