Optimize encode of StartupMessage

This commit is contained in:
Ryan Leckey 2019-07-11 10:08:57 -07:00
parent da42be7d0a
commit f161fa3178
4 changed files with 57 additions and 109 deletions

View File

@ -6,7 +6,7 @@ use sqlx_postgres_protocol::{Encode, PasswordMessage, Response, Severity, Startu
fn criterion_benchmark(c: &mut Criterion) { fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("encode Response::builder()", |b| { c.bench_function("encode Response::builder()", |b| {
let mut dst = Vec::new(); let mut dst = Vec::with_capacity(1024);
b.iter(|| { b.iter(|| {
dst.clear(); dst.clear();
Response::builder() Response::builder()
@ -22,7 +22,7 @@ fn criterion_benchmark(c: &mut Criterion) {
}); });
c.bench_function("encode PasswordMessage::cleartext", |b| { c.bench_function("encode PasswordMessage::cleartext", |b| {
let mut dst = Vec::new(); let mut dst = Vec::with_capacity(1024);
b.iter(|| { b.iter(|| {
dst.clear(); dst.clear();
PasswordMessage::cleartext("8e323AMF9YSE9zftFnuhQcvhz7Vf342W4cWU") PasswordMessage::cleartext("8e323AMF9YSE9zftFnuhQcvhz7Vf342W4cWU")
@ -32,31 +32,31 @@ fn criterion_benchmark(c: &mut Criterion) {
}); });
c.bench_function("encode StartupMessage", |b| { c.bench_function("encode StartupMessage", |b| {
let mut dst = Vec::new(); let mut dst = Vec::with_capacity(1024);
b.iter(|| { b.iter(|| {
dst.clear(); dst.clear();
StartupMessage::builder() StartupMessage::new(&[
.param("user", "postgres") ("user", "postgres"),
.param("database", "postgres") ("database", "postgres"),
.param("DateStyle", "ISO, MDY") ("DateStyle", "ISO, MDY"),
.param("IntervalStyle", "iso_8601") ("IntervalStyle", "iso_8601"),
.param("TimeZone", "UTC") ("TimeZone", "UTC"),
.param("extra_float_digits", "3") ("extra_float_digits", "3"),
.param("client_encoding", "UTF-8") ("client_encoding", "UTF-8"),
.build() ])
.encode(&mut dst) .encode(&mut dst)
.unwrap(); .unwrap();
}) })
}); });
c.bench_function("encode Password(MD5)", |b| { c.bench_function("encode Password(MD5)", |b| {
let mut dst = Vec::new(); let mut dst = Vec::with_capacity(1024);
b.iter(|| { b.iter(|| {
dst.clear(); dst.clear();
PasswordMessage::md5( PasswordMessage::md5(
"8e323AMF9YSE9zftFnuhQcvhz7Vf342W4cWU", "8e323AMF9YSE9zftFnuhQcvhz7Vf342W4cWU",
"postgres", "postgres",
&[10, 41, 20, 150], [10, 41, 20, 150],
) )
.encode(&mut dst) .encode(&mut dst)
.unwrap(); .unwrap();

View File

@ -1,91 +1,46 @@
use crate::Encode; use crate::Encode;
use bytes::{BufMut, Bytes, BytesMut}; use byteorder::{BigEndian, ByteOrder};
use std::io; use std::io;
#[derive(Debug)] #[derive(Debug)]
pub struct StartupMessage { pub struct StartupMessage<'a> {
// (major, minor) params: &'a [(&'a str, &'a str)],
version: (u16, u16),
params: Bytes,
} }
impl StartupMessage { impl<'a> StartupMessage<'a> {
pub fn builder() -> StartupMessageBuilder { pub fn new(params: &'a [(&'a str, &'a str)]) -> Self {
StartupMessageBuilder::new() Self { params }
} }
pub fn version(&self) -> (u16, u16) { pub fn params(&self) -> &'a [(&'a str, &'a str)] {
self.version self.params
}
pub fn params(&self) -> StartupMessageParams<'_> {
StartupMessageParams(&*self.params)
} }
} }
impl Encode for StartupMessage { impl<'a> Encode for StartupMessage<'a> {
fn encode(&self, buf: &mut Vec<u8>) -> io::Result<()> { fn encode(&self, buf: &mut Vec<u8>) -> io::Result<()> {
let len = self.params.len() + 8; let pos = buf.len();
buf.reserve(len); buf.extend_from_slice(&(0 as u32).to_be_bytes()); // skip over len
buf.put_u32_be(len as u32); buf.extend_from_slice(&3_u16.to_be_bytes()); // major version
buf.put_u16_be(self.version.0); buf.extend_from_slice(&0_u16.to_be_bytes()); // minor version
buf.put_u16_be(self.version.1);
buf.put(&self.params); for (name, value) in self.params {
buf.extend_from_slice(name.as_bytes());
buf.push(0);
buf.extend_from_slice(value.as_bytes());
buf.push(0);
}
buf.push(0);
// Write-back the len to the beginning of this frame
let len = buf.len() - pos;
BigEndian::write_u32(&mut buf[pos..], len as u32);
Ok(()) Ok(())
} }
} }
// TODO: Impl Iterator to iter over params
pub struct StartupMessageParams<'a>(&'a [u8]);
pub struct StartupMessageBuilder {
// (major, minor)
version: (u16, u16),
params: BytesMut,
}
impl Default for StartupMessageBuilder {
fn default() -> Self {
StartupMessageBuilder {
version: (3, 0),
params: BytesMut::with_capacity(156),
}
}
}
impl StartupMessageBuilder {
pub fn new() -> Self {
StartupMessageBuilder::default()
}
/// Set the protocol version number. Defaults to `3.0`.
pub fn version(mut self, major: u16, minor: u16) -> Self {
self.version = (major, minor);
self
}
pub fn param(mut self, name: &str, value: &str) -> Self {
self.params.reserve(name.len() + value.len() + 2);
self.params.put(name.as_bytes());
self.params.put_u8(0);
self.params.put(value.as_bytes());
self.params.put_u8(0);
self
}
pub fn build(mut self) -> StartupMessage {
self.params.reserve(1);
self.params.put_u8(0);
StartupMessage {
version: self.version,
params: self.params.freeze(),
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::StartupMessage; use super::StartupMessage;
@ -96,10 +51,7 @@ mod tests {
#[test] #[test]
fn it_encodes_startup_message() -> io::Result<()> { fn it_encodes_startup_message() -> io::Result<()> {
let message = StartupMessage::builder() let message = StartupMessage::new(&[("user", "postgres"), ("database", "postgres")]);
.param("user", "postgres")
.param("database", "postgres")
.build();
let mut buf = Vec::new(); let mut buf = Vec::new();
message.encode(&mut buf)?; message.encode(&mut buf)?;

View File

@ -10,31 +10,26 @@ pub async fn establish<'a, 'b: 'a>(
) -> io::Result<()> { ) -> io::Result<()> {
// See this doc for more runtime parameters // See this doc for more runtime parameters
// https://www.postgresql.org/docs/12/runtime-config-client.html // https://www.postgresql.org/docs/12/runtime-config-client.html
let mut message = StartupMessage::builder(); let params = &[
// FIXME: ConnectOptions user and database need to be required parameters and error
if let Some(user) = options.user { // before they get here
// FIXME: User is technically required. We should default this like psql does. ("user", options.user.expect("user is required")),
message = message.param("user", user); ("database", options.database.expect("database is required")),
}
if let Some(database) = options.database {
message = message.param("database", database);
}
let message = message
// Sets the display format for date and time values, // Sets the display format for date and time values,
// as well as the rules for interpreting ambiguous date input values. // as well as the rules for interpreting ambiguous date input values.
.param("DateStyle", "ISO, MDY") ("DateStyle", "ISO, MDY"),
// Sets the display format for interval values. // Sets the display format for interval values.
.param("IntervalStyle", "iso_8601") ("IntervalStyle", "iso_8601"),
// Sets the time zone for displaying and interpreting time stamps. // Sets the time zone for displaying and interpreting time stamps.
.param("TimeZone", "UTC") ("TimeZone", "UTC"),
// Adjust postgres to return percise values for floats // Adjust postgres to return percise values for floats
// NOTE: This is default in postgres 12+ // NOTE: This is default in postgres 12+
.param("extra_float_digits", "3") ("extra_float_digits", "3"),
// Sets the client-side encoding (character set). // Sets the client-side encoding (character set).
.param("client_encoding", "UTF-8") ("client_encoding", "UTF-8"),
.build(); ];
let message = StartupMessage::new(params);
conn.send(message).await?; conn.send(message).await?;

View File

@ -12,6 +12,7 @@ async fn main() -> io::Result<()> {
.host("127.0.0.1") .host("127.0.0.1")
.port(5432) .port(5432)
.user("postgres") .user("postgres")
.database("postgres")
.password("password"), .password("password"),
) )
.await?; .await?;