mirror of
https://github.com/launchbadge/sqlx.git
synced 2026-03-27 21:51:22 +00:00
Optimize encode of StartupMessage
This commit is contained in:
@@ -1,91 +1,46 @@
|
||||
use crate::Encode;
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use byteorder::{BigEndian, ByteOrder};
|
||||
use std::io;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct StartupMessage {
|
||||
// (major, minor)
|
||||
version: (u16, u16),
|
||||
params: Bytes,
|
||||
pub struct StartupMessage<'a> {
|
||||
params: &'a [(&'a str, &'a str)],
|
||||
}
|
||||
|
||||
impl StartupMessage {
|
||||
pub fn builder() -> StartupMessageBuilder {
|
||||
StartupMessageBuilder::new()
|
||||
impl<'a> StartupMessage<'a> {
|
||||
pub fn new(params: &'a [(&'a str, &'a str)]) -> Self {
|
||||
Self { params }
|
||||
}
|
||||
|
||||
pub fn version(&self) -> (u16, u16) {
|
||||
self.version
|
||||
}
|
||||
|
||||
pub fn params(&self) -> StartupMessageParams<'_> {
|
||||
StartupMessageParams(&*self.params)
|
||||
pub fn params(&self) -> &'a [(&'a str, &'a str)] {
|
||||
self.params
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode for StartupMessage {
|
||||
impl<'a> Encode for StartupMessage<'a> {
|
||||
fn encode(&self, buf: &mut Vec<u8>) -> io::Result<()> {
|
||||
let len = self.params.len() + 8;
|
||||
buf.reserve(len);
|
||||
buf.put_u32_be(len as u32);
|
||||
buf.put_u16_be(self.version.0);
|
||||
buf.put_u16_be(self.version.1);
|
||||
buf.put(&self.params);
|
||||
let pos = buf.len();
|
||||
buf.extend_from_slice(&(0 as u32).to_be_bytes()); // skip over len
|
||||
buf.extend_from_slice(&3_u16.to_be_bytes()); // major version
|
||||
buf.extend_from_slice(&0_u16.to_be_bytes()); // minor version
|
||||
|
||||
for (name, value) in self.params {
|
||||
buf.extend_from_slice(name.as_bytes());
|
||||
buf.push(0);
|
||||
buf.extend_from_slice(value.as_bytes());
|
||||
buf.push(0);
|
||||
}
|
||||
|
||||
buf.push(0);
|
||||
|
||||
// Write-back the len to the beginning of this frame
|
||||
let len = buf.len() - pos;
|
||||
BigEndian::write_u32(&mut buf[pos..], len as u32);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Impl Iterator to iter over params
|
||||
pub struct StartupMessageParams<'a>(&'a [u8]);
|
||||
|
||||
pub struct StartupMessageBuilder {
|
||||
// (major, minor)
|
||||
version: (u16, u16),
|
||||
params: BytesMut,
|
||||
}
|
||||
|
||||
impl Default for StartupMessageBuilder {
|
||||
fn default() -> Self {
|
||||
StartupMessageBuilder {
|
||||
version: (3, 0),
|
||||
params: BytesMut::with_capacity(156),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StartupMessageBuilder {
|
||||
pub fn new() -> Self {
|
||||
StartupMessageBuilder::default()
|
||||
}
|
||||
|
||||
/// Set the protocol version number. Defaults to `3.0`.
|
||||
pub fn version(mut self, major: u16, minor: u16) -> Self {
|
||||
self.version = (major, minor);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn param(mut self, name: &str, value: &str) -> Self {
|
||||
self.params.reserve(name.len() + value.len() + 2);
|
||||
self.params.put(name.as_bytes());
|
||||
self.params.put_u8(0);
|
||||
self.params.put(value.as_bytes());
|
||||
self.params.put_u8(0);
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(mut self) -> StartupMessage {
|
||||
self.params.reserve(1);
|
||||
self.params.put_u8(0);
|
||||
|
||||
StartupMessage {
|
||||
version: self.version,
|
||||
params: self.params.freeze(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::StartupMessage;
|
||||
@@ -96,10 +51,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn it_encodes_startup_message() -> io::Result<()> {
|
||||
let message = StartupMessage::builder()
|
||||
.param("user", "postgres")
|
||||
.param("database", "postgres")
|
||||
.build();
|
||||
let message = StartupMessage::new(&[("user", "postgres"), ("database", "postgres")]);
|
||||
|
||||
let mut buf = Vec::new();
|
||||
message.encode(&mut buf)?;
|
||||
|
||||
Reference in New Issue
Block a user