feat: implement read_packet for postgres

This commit is contained in:
Daniel Akhterov 2021-01-23 13:27:52 -08:00
parent d5053d1b1d
commit c8f7601ad1
No known key found for this signature in database
GPG Key ID: 80408CD2586A5A52
9 changed files with 444 additions and 85 deletions

18
Cargo.lock generated
View File

@ -150,6 +150,15 @@ version = "4.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e91831deabf0d6d7ec49552e489aed63b7456a7a3c46cff62adad428110b0af0"
[[package]]
name = "atoi"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "616896e05fc0e2649463a93a15183c6a16bf03413a7af88ef1285ddedfa9cda5"
dependencies = [
"num-traits",
]
[[package]]
name = "atomic-waker"
version = "1.0.0"
@ -631,6 +640,12 @@ version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08"
[[package]]
name = "md5"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
[[package]]
name = "memchr"
version = "2.3.4"
@ -1110,8 +1125,10 @@ name = "sqlx-postgres"
version = "0.6.0-pre"
dependencies = [
"anyhow",
"atoi",
"base64",
"bitflags",
"byteorder",
"bytes",
"bytestring",
"either",
@ -1120,6 +1137,7 @@ dependencies = [
"futures-util",
"itoa",
"log",
"md5",
"memchr",
"percent-encoding",
"rand",

View File

@ -44,6 +44,9 @@ rsa = "0.3.0"
base64 = "0.13.0"
rand = "0.7"
itoa = "0.4.7"
atoi = "0.4.0"
byteorder = "1.4.2"
md5 = "0.7.0"
[dev-dependencies]
sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core", features = ["_mock"] }

View File

@ -15,7 +15,9 @@ use sqlx_core::net::Stream as NetStream;
use sqlx_core::Error;
use sqlx_core::Result;
use crate::protocol::{Message, MessageType, Startup};
use crate::protocol::{
Authentication, BackendKeyData, Message, MessageType, Password, ReadyForQuery, Startup,
};
use crate::{PostgresConnectOptions, PostgresConnection};
macro_rules! connect {
@ -28,11 +30,11 @@ macro_rules! connect {
};
(@blocking @packet $self:ident) => {
$self.read_message()?;
$self.read_packet()?;
};
(@packet $self:ident) => {
$self.read_message_async().await?;
$self.read_packet_async().await?;
};
($(@$blocking:ident)? $options:ident) => {{
@ -83,39 +85,37 @@ macro_rules! connect {
let message: Message = connect!($(@$blocking)? @packet self_);
match message.r#type {
MessageType::Authentication => match message.decode()? {
// Authentication::Ok => {
Authentication::Ok => {
// the authentication exchange is successfully completed
// do nothing; no more information is required to continue
// }
}
// Authentication::CleartextPassword => {
// // The frontend must now send a [PasswordMessage] containing the
// // password in clear-text form.
Authentication::CleartextPassword => {
// The frontend must now send a [PasswordMessage] containing the
// password in clear-text form.
// stream
// .send(Password::Cleartext(
// options.password.as_deref().unwrap_or_default(),
// ))
// .await?;
// }
self_
.write_packet(&Password::Cleartext(
$options.get_password().unwrap_or_default(),
))?;
}
// Authentication::Md5Password(body) => {
// // The frontend must now send a [PasswordMessage] containing the
// // password (with user name) encrypted via MD5, then encrypted again
// // using the 4-byte random salt specified in the
// // [AuthenticationMD5Password] message.
Authentication::Md5Password(body) => {
// The frontend must now send a [PasswordMessage] containing the
// password (with user name) encrypted via MD5, then encrypted again
// using the 4-byte random salt specified in the
// [AuthenticationMD5Password] message.
// stream
// .send(Password::Md5 {
// username: &options.username,
// password: options.password.as_deref().unwrap_or_default(),
// salt: body.salt,
// })
// .await?;
// }
self_
.write_packet(&Password::Md5 {
username: $options.get_username().unwrap_or_default(),
password: $options.get_password().unwrap_or_default(),
salt: body.salt,
})?;
}
// Authentication::Sasl(body) => {
// sasl::authenticate(&mut stream, options, body).await?;
// sasl::authenticate(&mut stream, $options, body).await?;
// }
method => {
@ -126,28 +126,29 @@ macro_rules! connect {
}
},
// MessageFormat::BackendKeyData => {
// // provides secret-key data that the frontend must save if it wants to be
// // able to issue cancel requests later
MessageType::BackendKeyData => {
// provides secret-key data that the frontend must save if it wants to be
// able to issue cancel requests later
// let data: BackendKeyData = message.decode()?;
let data: BackendKeyData = message.decode()?;
// process_id = data.process_id;
// secret_key = data.secret_key;
// }
process_id = data.process_id;
secret_key = data.secret_key;
}
// MessageFormat::ReadyForQuery => {
// // start-up is completed. The frontend can now issue commands
// transaction_status =
// ReadyForQuery::decode(message.contents)?.transaction_status;
MessageType::ReadyForQuery => {
let ready: ReadyForQuery = message.decode()?;
// break;
// }
// start-up is completed. The frontend can now issue commands
transaction_status = ready.transaction_status;
break;
}
_ => {
return Err(Error::configuration_msg(format!(
"establish: unexpected message: {:?}",
message.format
message.r#type
)))
}
}

View File

@ -46,26 +46,23 @@ where
Ok(())
}
}
pub(crate) fn recv_message(&mut self) -> Result<Message> {
// all packets in postgres start with a 5-byte header
// this header contains the message type and the total length of the message
let mut header: Bytes = self.stream.take(5);
let r#type = MessageType::try_from(header.get_u8())?;
let size = (header.get_u32() - 4) as usize;
let contents = self.stream.take(size);
Ok(Message { r#type, contents })
}
fn recv_packet<'de, T>(&'de mut self, len: usize) -> Result<T>
where
T: Deserialize<'de, ()> + Debug,
{
macro_rules! read_packet {
($(@$blocking:ident)? $self:ident) => {{
loop {
let message = self.recv_message()?;
read_packet!($(@$blocking)? @stream $self, 0, 5);
let mut header: Bytes = $self.stream.take(5);
let r#type = MessageType::try_from(header.get_u8())?;
let size = (header.get_u32() - 4) as usize;
read_packet!($(@$blocking)? @stream $self, 4, size);
let contents = $self.stream.take(size);
let message = Message { r#type, contents };
match message.r#type {
MessageType::ErrorResponse => {
@ -127,28 +124,6 @@ where
return T::deserialize_with(message.contents, ());
}
}
}
macro_rules! read_packet {
($(@$blocking:ident)? $self:ident) => {{
// reads at least 4 bytes from the IO stream into the read buffer
read_packet!($(@$blocking)? @stream $self, 0, 4);
// the first 3 bytes will be the payload length of the packet (in LE)
// ALLOW: the max this len will be is 16M
#[allow(clippy::cast_possible_truncation)]
let payload_len: usize = $self.stream.get(0, 3).get_uint_le(3) as usize;
// read <payload_len> bytes _after_ the 4 byte packet header
// note that we have not yet told the stream we are done with any of
// these bytes yet. if this next read invocation were to never return (eg., the
// outer future was dropped), then the next time read_packet_async was called
// it will re-read the parsed-above packet header. Note that we have NOT
// mutated `self` _yet_. This is important.
read_packet!($(@$blocking)? @stream $self, 4, payload_len);
$self.recv_packet(payload_len)
}};
(@blocking @stream $self:ident, $offset:expr, $n:expr) => {

View File

@ -1,18 +1,26 @@
use std::convert::TryFrom;
use bytes::Bytes;
use bytes::{Buf, Bytes};
use sqlx_core::io::Deserialize;
use sqlx_core::Error;
use sqlx_core::Result;
mod authentication;
mod backend_key_data;
mod close;
mod notification;
mod password;
mod ready_for_query;
mod response;
mod startup;
mod terminate;
pub(crate) use authentication::{Authentication, AuthenticationSasl};
pub(crate) use backend_key_data::BackendKeyData;
pub(crate) use close::Close;
pub(crate) use notification::Notification;
pub(crate) use password::Password;
pub(crate) use ready_for_query::ReadyForQuery;
pub(crate) use response::{Notice, PgSeverity};
pub(crate) use startup::Startup;
pub(crate) use terminate::Terminate;
@ -28,7 +36,7 @@ pub enum MessageType {
ErrorResponse = b'E',
EmptyQueryResponse = b'I',
NotificationResponse = b'A',
KeyData = b'K',
BackendKeyData = b'K',
NoticeResponse = b'N',
Authentication = b'R',
ParameterStatus = b'S',
@ -70,7 +78,7 @@ impl TryFrom<u8> for MessageType {
b'E' => MessageType::ErrorResponse,
b'I' => MessageType::EmptyQueryResponse,
b'A' => MessageType::NotificationResponse,
b'K' => MessageType::KeyData,
b'K' => MessageType::BackendKeyData,
b'N' => MessageType::NoticeResponse,
b'R' => MessageType::Authentication,
b'S' => MessageType::ParameterStatus,
@ -89,3 +97,13 @@ impl TryFrom<u8> for MessageType {
})
}
}
impl Deserialize<'_, ()> for Message {
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
let r#type = MessageType::try_from(buf.get_u8())?;
let size = buf.get_u32() - 4;
let contents = buf.split_to(size as usize);
Ok(Message { r#type, contents })
}
}

View File

@ -0,0 +1,204 @@
use bytes::{Buf, Bytes};
use memchr::memchr;
use sqlx_core::io::Deserialize;
use sqlx_core::Error;
use sqlx_core::Result;
// On startup, the server sends an appropriate authentication request message,
// to which the frontend must reply with an appropriate authentication
// response message (such as a password).
// For all authentication methods except GSSAPI, SSPI and SASL, there is at
// most one request and one response. In some methods, no response at all is
// needed from the frontend, and so no authentication request occurs.
// For GSSAPI, SSPI and SASL, multiple exchanges of packets may
// be needed to complete the authentication.
// <https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.3>
// <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
#[derive(Debug)]
pub enum Authentication {
/// The authentication exchange is successfully completed.
Ok,
/// The frontend must now send a [PasswordMessage] containing the
/// password in clear-text form.
CleartextPassword,
/// The frontend must now send a [PasswordMessage] containing the
/// password (with user name) encrypted via MD5, then encrypted
/// again using the 4-byte random salt.
Md5Password(AuthenticationMd5Password),
/// The frontend must now initiate a SASL negotiation,
/// using one of the SASL mechanisms listed in the message.
///
/// The frontend will send a [SaslInitialResponse] with the name
/// of the selected mechanism, and the first part of the SASL
/// data stream in response to this.
///
/// If further messages are needed, the server will
/// respond with [Authentication::SaslContinue].
Sasl(AuthenticationSasl),
/// This message contains challenge data from the previous step of SASL negotiation.
///
/// The frontend must respond with a [SaslResponse] message.
SaslContinue(AuthenticationSaslContinue),
/// SASL authentication has completed with additional mechanism-specific
/// data for the client.
///
/// The server will next send [Authentication::Ok] to
/// indicate successful authentication.
SaslFinal(AuthenticationSaslFinal),
}
impl Deserialize<'_, ()> for Authentication {
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
Ok(match buf.get_u32() {
0 => Authentication::Ok,
3 => Authentication::CleartextPassword,
5 => {
let mut salt = [0; 4];
buf.copy_to_slice(&mut salt);
Authentication::Md5Password(AuthenticationMd5Password { salt })
}
10 => Authentication::Sasl(AuthenticationSasl(buf)),
11 => {
Authentication::SaslContinue(AuthenticationSaslContinue::deserialize_with(buf, ())?)
}
12 => Authentication::SaslFinal(AuthenticationSaslFinal::deserialize_with(buf, ())?),
ty => {
return Err(Error::configuration_msg(format!(
"unknown authentication method: {}",
ty
)));
}
})
}
}
/// Body of [Authentication::Md5Password].
#[derive(Debug)]
pub struct AuthenticationMd5Password {
pub salt: [u8; 4],
}
/// Body of [Authentication::Sasl].
#[derive(Debug)]
pub struct AuthenticationSasl(Bytes);
impl AuthenticationSasl {
#[inline]
pub fn mechanisms(&self) -> SaslMechanisms<'_> {
SaslMechanisms(&self.0)
}
}
/// An iterator over the SASL authentication mechanisms provided by the server.
pub struct SaslMechanisms<'a>(&'a [u8]);
impl<'a> Iterator for SaslMechanisms<'a> {
type Item = &'a str;
fn next(&mut self) -> Option<Self::Item> {
if !self.0.is_empty() && self.0[0] == b'\0' {
return None;
}
#[allow(unsafe_code)]
let mechanism = memchr(b'\0', self.0)
// UNSAFE: Postgres is expecte to return a valid UTF-8 string here
.and_then(|nul| Some(unsafe { std::str::from_utf8_unchecked(&self.0[..nul]) }))?;
self.0 = &self.0[(mechanism.len() + 1)..];
Some(mechanism)
}
}
#[derive(Debug)]
pub struct AuthenticationSaslContinue {
pub salt: Vec<u8>,
pub iterations: u32,
pub nonce: String,
pub message: String,
}
impl Deserialize<'_, ()> for AuthenticationSaslContinue {
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
let mut iterations: u32 = 4096;
let mut salt = Vec::new();
let mut nonce = Bytes::new();
// [Example]
// r=/z+giZiTxAH7r8sNAeHr7cvpqV3uo7G/bJBIJO3pjVM7t3ng,s=4UV68bIkC8f9/X8xH7aPhg==,i=4096
for item in buf.split(|b| *b == b',') {
let key = item[0];
let value = &item[2..];
match key {
b'r' => {
nonce = buf.slice_ref(value);
}
b'i' => {
iterations = atoi::atoi(value).unwrap_or(4096);
}
b's' => {
// TODO: Map error correctly
salt = base64::decode(value).unwrap();
}
_ => {}
}
}
#[allow(unsafe_code)]
Ok(Self {
iterations,
salt,
// UNSAFE: Postgres is expected to return a valid UTF-8 string here
nonce: unsafe { String::from_utf8_unchecked((*nonce).to_vec()) },
// UNSAFE: Postgres is expected to return a valid UTF-8 string here
message: unsafe { String::from_utf8_unchecked((*buf).to_vec()) },
})
}
}
#[derive(Debug)]
pub struct AuthenticationSaslFinal {
pub verifier: Vec<u8>,
}
impl Deserialize<'_, ()> for AuthenticationSaslFinal {
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
let mut verifier = Vec::new();
for item in buf.split(|b| *b == b',') {
let key = item[0];
let value = &item[2..];
if let b'v' = key {
// TODO: Map error correctly
verifier = base64::decode(value).unwrap();
}
}
Ok(Self { verifier })
}
}

View File

@ -0,0 +1,25 @@
use byteorder::{BigEndian, ByteOrder};
use bytes::Bytes;
use sqlx_core::io::Deserialize;
use sqlx_core::Error;
use sqlx_core::Result;
/// Contains cancellation key data. The frontend must save these values if it
/// wishes to be able to issue `CancelRequest` messages later.
#[derive(Debug)]
pub struct BackendKeyData {
/// The process ID of this database.
pub process_id: u32,
/// The secret key of this database.
pub secret_key: u32,
}
impl Deserialize<'_, ()> for BackendKeyData {
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
let process_id = BigEndian::read_u32(&buf);
let secret_key = BigEndian::read_u32(&buf[4..]);
Ok(Self { process_id, secret_key })
}
}

View File

@ -0,0 +1,64 @@
use std::fmt::Write;
use sqlx_core::io::Serialize;
use sqlx_core::io::WriteExt;
use sqlx_core::Result;
use crate::io::PgBufMutExt;
#[derive(Debug)]
pub enum Password<'a> {
Cleartext(&'a str),
Md5 { password: &'a str, username: &'a str, salt: [u8; 4] },
}
impl Password<'_> {
#[inline]
fn len(&self) -> usize {
match self {
Password::Cleartext(s) => s.len() + 5,
Password::Md5 { .. } => 35 + 5,
}
}
}
impl Serialize<'_, ()> for Password<'_> {
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
buf.reserve(1 + 4 + self.len());
buf.push(b'p');
buf.write_length_prefixed(|buf| {
match self {
Password::Cleartext(password) => {
buf.write_str_nul(password);
}
Password::Md5 { username, password, salt } => {
// The actual `PasswordMessage` can be comwriteed in SQL as
// `concat('md5', md5(concat(md5(concat(password, username)), random-salt)))`.
// Keep in mind the md5() function returns its result as a hex string.
let digest = md5::compute(password);
let digest = md5::compute(username);
let mut outwrite = String::with_capacity(35);
let _ = write!(outwrite, "{:x}", digest);
let digest = md5::compute(&outwrite);
let digest = md5::compute(salt);
outwrite.clear();
let _ = write!(outwrite, "md5{:x}", digest);
buf.write_str_nul(&outwrite);
}
}
});
Ok(())
}
}

View File

@ -0,0 +1,51 @@
use std::convert::TryFrom;
use bytes::Bytes;
use sqlx_core::io::Deserialize;
use sqlx_core::Error;
use sqlx_core::Result;
#[derive(Debug)]
#[repr(u8)]
pub(crate) enum TransactionStatus {
/// Not in a transaction block.
Idle = b'I',
/// In a transaction block.
Transaction = b'T',
/// In a _failed_ transaction block. Queries will be rejected until block is ended.
Error = b'E',
}
impl TryFrom<u8> for TransactionStatus {
type Error = Error;
fn try_from(value: u8) -> Result<Self> {
match value {
b'I' => Ok(TransactionStatus::Idle),
b'T' => Ok(TransactionStatus::Transaction),
b'E' => Ok(TransactionStatus::Error),
status => {
return Err(Error::configuration_msg(format!(
"unknown transaction status: {:?}",
status as char,
)));
}
}
}
}
#[derive(Debug)]
pub(crate) struct ReadyForQuery {
pub transaction_status: TransactionStatus,
}
impl Deserialize<'_, ()> for ReadyForQuery {
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
let transaction_status = TransactionStatus::try_from(buf[0])?;
Ok(Self { transaction_status })
}
}