From c8f7601ad186e78ae7e074a966275276349631f2 Mon Sep 17 00:00:00 2001 From: Daniel Akhterov Date: Sat, 23 Jan 2021 13:27:52 -0800 Subject: [PATCH] feat: implement `read_packet` for postgres --- Cargo.lock | 18 ++ sqlx-postgres/Cargo.toml | 3 + sqlx-postgres/src/connection/connect.rs | 85 ++++---- sqlx-postgres/src/connection/stream.rs | 55 ++--- sqlx-postgres/src/protocol.rs | 24 ++- sqlx-postgres/src/protocol/authentication.rs | 204 ++++++++++++++++++ .../src/protocol/backend_key_data.rs | 25 +++ sqlx-postgres/src/protocol/password.rs | 64 ++++++ sqlx-postgres/src/protocol/ready_for_query.rs | 51 +++++ 9 files changed, 444 insertions(+), 85 deletions(-) create mode 100644 sqlx-postgres/src/protocol/authentication.rs create mode 100644 sqlx-postgres/src/protocol/backend_key_data.rs create mode 100644 sqlx-postgres/src/protocol/password.rs create mode 100644 sqlx-postgres/src/protocol/ready_for_query.rs diff --git a/Cargo.lock b/Cargo.lock index 3450b9ba..82327364 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -150,6 +150,15 @@ version = "4.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e91831deabf0d6d7ec49552e489aed63b7456a7a3c46cff62adad428110b0af0" +[[package]] +name = "atoi" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "616896e05fc0e2649463a93a15183c6a16bf03413a7af88ef1285ddedfa9cda5" +dependencies = [ + "num-traits", +] + [[package]] name = "atomic-waker" version = "1.0.0" @@ -631,6 +640,12 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08" +[[package]] +name = "md5" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" + [[package]] name = "memchr" version = "2.3.4" @@ -1110,8 +1125,10 @@ name = "sqlx-postgres" version = "0.6.0-pre" dependencies = [ "anyhow", + "atoi", "base64", "bitflags", + "byteorder", "bytes", "bytestring", "either", @@ -1120,6 +1137,7 @@ dependencies = [ "futures-util", "itoa", "log", + "md5", "memchr", "percent-encoding", "rand", diff --git a/sqlx-postgres/Cargo.toml b/sqlx-postgres/Cargo.toml index 40abec0e..57e48372 100644 --- a/sqlx-postgres/Cargo.toml +++ b/sqlx-postgres/Cargo.toml @@ -44,6 +44,9 @@ rsa = "0.3.0" base64 = "0.13.0" rand = "0.7" itoa = "0.4.7" +atoi = "0.4.0" +byteorder = "1.4.2" +md5 = "0.7.0" [dev-dependencies] sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core", features = ["_mock"] } diff --git a/sqlx-postgres/src/connection/connect.rs b/sqlx-postgres/src/connection/connect.rs index 33fe7396..3e01c131 100644 --- a/sqlx-postgres/src/connection/connect.rs +++ b/sqlx-postgres/src/connection/connect.rs @@ -15,7 +15,9 @@ use sqlx_core::net::Stream as NetStream; use sqlx_core::Error; use sqlx_core::Result; -use crate::protocol::{Message, MessageType, Startup}; +use crate::protocol::{ + Authentication, BackendKeyData, Message, MessageType, Password, ReadyForQuery, Startup, +}; use crate::{PostgresConnectOptions, PostgresConnection}; macro_rules! connect { @@ -28,11 +30,11 @@ macro_rules! connect { }; (@blocking @packet $self:ident) => { - $self.read_message()?; + $self.read_packet()?; }; (@packet $self:ident) => { - $self.read_message_async().await?; + $self.read_packet_async().await?; }; ($(@$blocking:ident)? $options:ident) => {{ @@ -83,39 +85,37 @@ macro_rules! connect { let message: Message = connect!($(@$blocking)? @packet self_); match message.r#type { MessageType::Authentication => match message.decode()? { - // Authentication::Ok => { + Authentication::Ok => { // the authentication exchange is successfully completed // do nothing; no more information is required to continue - // } + } - // Authentication::CleartextPassword => { - // // The frontend must now send a [PasswordMessage] containing the - // // password in clear-text form. + Authentication::CleartextPassword => { + // The frontend must now send a [PasswordMessage] containing the + // password in clear-text form. - // stream - // .send(Password::Cleartext( - // options.password.as_deref().unwrap_or_default(), - // )) - // .await?; - // } + self_ + .write_packet(&Password::Cleartext( + $options.get_password().unwrap_or_default(), + ))?; + } - // Authentication::Md5Password(body) => { - // // The frontend must now send a [PasswordMessage] containing the - // // password (with user name) encrypted via MD5, then encrypted again - // // using the 4-byte random salt specified in the - // // [AuthenticationMD5Password] message. + Authentication::Md5Password(body) => { + // The frontend must now send a [PasswordMessage] containing the + // password (with user name) encrypted via MD5, then encrypted again + // using the 4-byte random salt specified in the + // [AuthenticationMD5Password] message. - // stream - // .send(Password::Md5 { - // username: &options.username, - // password: options.password.as_deref().unwrap_or_default(), - // salt: body.salt, - // }) - // .await?; - // } + self_ + .write_packet(&Password::Md5 { + username: $options.get_username().unwrap_or_default(), + password: $options.get_password().unwrap_or_default(), + salt: body.salt, + })?; + } // Authentication::Sasl(body) => { - // sasl::authenticate(&mut stream, options, body).await?; + // sasl::authenticate(&mut stream, $options, body).await?; // } method => { @@ -126,28 +126,29 @@ macro_rules! connect { } }, - // MessageFormat::BackendKeyData => { - // // provides secret-key data that the frontend must save if it wants to be - // // able to issue cancel requests later + MessageType::BackendKeyData => { + // provides secret-key data that the frontend must save if it wants to be + // able to issue cancel requests later - // let data: BackendKeyData = message.decode()?; + let data: BackendKeyData = message.decode()?; - // process_id = data.process_id; - // secret_key = data.secret_key; - // } + process_id = data.process_id; + secret_key = data.secret_key; + } - // MessageFormat::ReadyForQuery => { - // // start-up is completed. The frontend can now issue commands - // transaction_status = - // ReadyForQuery::decode(message.contents)?.transaction_status; + MessageType::ReadyForQuery => { + let ready: ReadyForQuery = message.decode()?; - // break; - // } + // start-up is completed. The frontend can now issue commands + transaction_status = ready.transaction_status; + + break; + } _ => { return Err(Error::configuration_msg(format!( "establish: unexpected message: {:?}", - message.format + message.r#type ))) } } diff --git a/sqlx-postgres/src/connection/stream.rs b/sqlx-postgres/src/connection/stream.rs index 9529fede..7559cb4f 100644 --- a/sqlx-postgres/src/connection/stream.rs +++ b/sqlx-postgres/src/connection/stream.rs @@ -46,26 +46,23 @@ where Ok(()) } +} - pub(crate) fn recv_message(&mut self) -> Result { - // all packets in postgres start with a 5-byte header - // this header contains the message type and the total length of the message - let mut header: Bytes = self.stream.take(5); - - let r#type = MessageType::try_from(header.get_u8())?; - let size = (header.get_u32() - 4) as usize; - - let contents = self.stream.take(size); - - Ok(Message { r#type, contents }) - } - - fn recv_packet<'de, T>(&'de mut self, len: usize) -> Result - where - T: Deserialize<'de, ()> + Debug, - { +macro_rules! read_packet { + ($(@$blocking:ident)? $self:ident) => {{ loop { - let message = self.recv_message()?; + read_packet!($(@$blocking)? @stream $self, 0, 5); + + let mut header: Bytes = $self.stream.take(5); + + let r#type = MessageType::try_from(header.get_u8())?; + let size = (header.get_u32() - 4) as usize; + + read_packet!($(@$blocking)? @stream $self, 4, size); + + let contents = $self.stream.take(size); + + let message = Message { r#type, contents }; match message.r#type { MessageType::ErrorResponse => { @@ -127,28 +124,6 @@ where return T::deserialize_with(message.contents, ()); } - } -} - -macro_rules! read_packet { - ($(@$blocking:ident)? $self:ident) => {{ - // reads at least 4 bytes from the IO stream into the read buffer - read_packet!($(@$blocking)? @stream $self, 0, 4); - - // the first 3 bytes will be the payload length of the packet (in LE) - // ALLOW: the max this len will be is 16M - #[allow(clippy::cast_possible_truncation)] - let payload_len: usize = $self.stream.get(0, 3).get_uint_le(3) as usize; - - // read bytes _after_ the 4 byte packet header - // note that we have not yet told the stream we are done with any of - // these bytes yet. if this next read invocation were to never return (eg., the - // outer future was dropped), then the next time read_packet_async was called - // it will re-read the parsed-above packet header. Note that we have NOT - // mutated `self` _yet_. This is important. - read_packet!($(@$blocking)? @stream $self, 4, payload_len); - - $self.recv_packet(payload_len) }}; (@blocking @stream $self:ident, $offset:expr, $n:expr) => { diff --git a/sqlx-postgres/src/protocol.rs b/sqlx-postgres/src/protocol.rs index fe00f0b0..5c37dfed 100644 --- a/sqlx-postgres/src/protocol.rs +++ b/sqlx-postgres/src/protocol.rs @@ -1,18 +1,26 @@ use std::convert::TryFrom; -use bytes::Bytes; +use bytes::{Buf, Bytes}; use sqlx_core::io::Deserialize; use sqlx_core::Error; use sqlx_core::Result; +mod authentication; +mod backend_key_data; mod close; mod notification; +mod password; +mod ready_for_query; mod response; mod startup; mod terminate; +pub(crate) use authentication::{Authentication, AuthenticationSasl}; +pub(crate) use backend_key_data::BackendKeyData; pub(crate) use close::Close; pub(crate) use notification::Notification; +pub(crate) use password::Password; +pub(crate) use ready_for_query::ReadyForQuery; pub(crate) use response::{Notice, PgSeverity}; pub(crate) use startup::Startup; pub(crate) use terminate::Terminate; @@ -28,7 +36,7 @@ pub enum MessageType { ErrorResponse = b'E', EmptyQueryResponse = b'I', NotificationResponse = b'A', - KeyData = b'K', + BackendKeyData = b'K', NoticeResponse = b'N', Authentication = b'R', ParameterStatus = b'S', @@ -70,7 +78,7 @@ impl TryFrom for MessageType { b'E' => MessageType::ErrorResponse, b'I' => MessageType::EmptyQueryResponse, b'A' => MessageType::NotificationResponse, - b'K' => MessageType::KeyData, + b'K' => MessageType::BackendKeyData, b'N' => MessageType::NoticeResponse, b'R' => MessageType::Authentication, b'S' => MessageType::ParameterStatus, @@ -89,3 +97,13 @@ impl TryFrom for MessageType { }) } } + +impl Deserialize<'_, ()> for Message { + fn deserialize_with(mut buf: Bytes, _: ()) -> Result { + let r#type = MessageType::try_from(buf.get_u8())?; + let size = buf.get_u32() - 4; + let contents = buf.split_to(size as usize); + + Ok(Message { r#type, contents }) + } +} diff --git a/sqlx-postgres/src/protocol/authentication.rs b/sqlx-postgres/src/protocol/authentication.rs new file mode 100644 index 00000000..8479c9d4 --- /dev/null +++ b/sqlx-postgres/src/protocol/authentication.rs @@ -0,0 +1,204 @@ +use bytes::{Buf, Bytes}; +use memchr::memchr; +use sqlx_core::io::Deserialize; +use sqlx_core::Error; +use sqlx_core::Result; + +// On startup, the server sends an appropriate authentication request message, +// to which the frontend must reply with an appropriate authentication +// response message (such as a password). + +// For all authentication methods except GSSAPI, SSPI and SASL, there is at +// most one request and one response. In some methods, no response at all is +// needed from the frontend, and so no authentication request occurs. + +// For GSSAPI, SSPI and SASL, multiple exchanges of packets may +// be needed to complete the authentication. + +// +// + +#[derive(Debug)] +pub enum Authentication { + /// The authentication exchange is successfully completed. + Ok, + + /// The frontend must now send a [PasswordMessage] containing the + /// password in clear-text form. + CleartextPassword, + + /// The frontend must now send a [PasswordMessage] containing the + /// password (with user name) encrypted via MD5, then encrypted + /// again using the 4-byte random salt. + Md5Password(AuthenticationMd5Password), + + /// The frontend must now initiate a SASL negotiation, + /// using one of the SASL mechanisms listed in the message. + /// + /// The frontend will send a [SaslInitialResponse] with the name + /// of the selected mechanism, and the first part of the SASL + /// data stream in response to this. + /// + /// If further messages are needed, the server will + /// respond with [Authentication::SaslContinue]. + Sasl(AuthenticationSasl), + + /// This message contains challenge data from the previous step of SASL negotiation. + /// + /// The frontend must respond with a [SaslResponse] message. + SaslContinue(AuthenticationSaslContinue), + + /// SASL authentication has completed with additional mechanism-specific + /// data for the client. + /// + /// The server will next send [Authentication::Ok] to + /// indicate successful authentication. + SaslFinal(AuthenticationSaslFinal), +} + +impl Deserialize<'_, ()> for Authentication { + fn deserialize_with(mut buf: Bytes, _: ()) -> Result { + Ok(match buf.get_u32() { + 0 => Authentication::Ok, + + 3 => Authentication::CleartextPassword, + + 5 => { + let mut salt = [0; 4]; + buf.copy_to_slice(&mut salt); + + Authentication::Md5Password(AuthenticationMd5Password { salt }) + } + + 10 => Authentication::Sasl(AuthenticationSasl(buf)), + + 11 => { + Authentication::SaslContinue(AuthenticationSaslContinue::deserialize_with(buf, ())?) + } + + 12 => Authentication::SaslFinal(AuthenticationSaslFinal::deserialize_with(buf, ())?), + + ty => { + return Err(Error::configuration_msg(format!( + "unknown authentication method: {}", + ty + ))); + } + }) + } +} + +/// Body of [Authentication::Md5Password]. +#[derive(Debug)] +pub struct AuthenticationMd5Password { + pub salt: [u8; 4], +} + +/// Body of [Authentication::Sasl]. +#[derive(Debug)] +pub struct AuthenticationSasl(Bytes); + +impl AuthenticationSasl { + #[inline] + pub fn mechanisms(&self) -> SaslMechanisms<'_> { + SaslMechanisms(&self.0) + } +} + +/// An iterator over the SASL authentication mechanisms provided by the server. +pub struct SaslMechanisms<'a>(&'a [u8]); + +impl<'a> Iterator for SaslMechanisms<'a> { + type Item = &'a str; + + fn next(&mut self) -> Option { + if !self.0.is_empty() && self.0[0] == b'\0' { + return None; + } + + #[allow(unsafe_code)] + let mechanism = memchr(b'\0', self.0) + // UNSAFE: Postgres is expecte to return a valid UTF-8 string here + .and_then(|nul| Some(unsafe { std::str::from_utf8_unchecked(&self.0[..nul]) }))?; + + self.0 = &self.0[(mechanism.len() + 1)..]; + + Some(mechanism) + } +} + +#[derive(Debug)] +pub struct AuthenticationSaslContinue { + pub salt: Vec, + pub iterations: u32, + pub nonce: String, + pub message: String, +} + +impl Deserialize<'_, ()> for AuthenticationSaslContinue { + fn deserialize_with(mut buf: Bytes, _: ()) -> Result { + let mut iterations: u32 = 4096; + let mut salt = Vec::new(); + let mut nonce = Bytes::new(); + + // [Example] + // r=/z+giZiTxAH7r8sNAeHr7cvpqV3uo7G/bJBIJO3pjVM7t3ng,s=4UV68bIkC8f9/X8xH7aPhg==,i=4096 + + for item in buf.split(|b| *b == b',') { + let key = item[0]; + let value = &item[2..]; + + match key { + b'r' => { + nonce = buf.slice_ref(value); + } + + b'i' => { + iterations = atoi::atoi(value).unwrap_or(4096); + } + + b's' => { + // TODO: Map error correctly + salt = base64::decode(value).unwrap(); + } + + _ => {} + } + } + + #[allow(unsafe_code)] + Ok(Self { + iterations, + salt, + + // UNSAFE: Postgres is expected to return a valid UTF-8 string here + nonce: unsafe { String::from_utf8_unchecked((*nonce).to_vec()) }, + + // UNSAFE: Postgres is expected to return a valid UTF-8 string here + message: unsafe { String::from_utf8_unchecked((*buf).to_vec()) }, + }) + } +} + +#[derive(Debug)] +pub struct AuthenticationSaslFinal { + pub verifier: Vec, +} + +impl Deserialize<'_, ()> for AuthenticationSaslFinal { + fn deserialize_with(mut buf: Bytes, _: ()) -> Result { + let mut verifier = Vec::new(); + + for item in buf.split(|b| *b == b',') { + let key = item[0]; + let value = &item[2..]; + + if let b'v' = key { + // TODO: Map error correctly + verifier = base64::decode(value).unwrap(); + } + } + + Ok(Self { verifier }) + } +} diff --git a/sqlx-postgres/src/protocol/backend_key_data.rs b/sqlx-postgres/src/protocol/backend_key_data.rs new file mode 100644 index 00000000..0e9ebbd9 --- /dev/null +++ b/sqlx-postgres/src/protocol/backend_key_data.rs @@ -0,0 +1,25 @@ +use byteorder::{BigEndian, ByteOrder}; +use bytes::Bytes; +use sqlx_core::io::Deserialize; +use sqlx_core::Error; +use sqlx_core::Result; + +/// Contains cancellation key data. The frontend must save these values if it +/// wishes to be able to issue `CancelRequest` messages later. +#[derive(Debug)] +pub struct BackendKeyData { + /// The process ID of this database. + pub process_id: u32, + + /// The secret key of this database. + pub secret_key: u32, +} + +impl Deserialize<'_, ()> for BackendKeyData { + fn deserialize_with(mut buf: Bytes, _: ()) -> Result { + let process_id = BigEndian::read_u32(&buf); + let secret_key = BigEndian::read_u32(&buf[4..]); + + Ok(Self { process_id, secret_key }) + } +} diff --git a/sqlx-postgres/src/protocol/password.rs b/sqlx-postgres/src/protocol/password.rs new file mode 100644 index 00000000..665bbe43 --- /dev/null +++ b/sqlx-postgres/src/protocol/password.rs @@ -0,0 +1,64 @@ +use std::fmt::Write; + +use sqlx_core::io::Serialize; +use sqlx_core::io::WriteExt; +use sqlx_core::Result; + +use crate::io::PgBufMutExt; + +#[derive(Debug)] +pub enum Password<'a> { + Cleartext(&'a str), + + Md5 { password: &'a str, username: &'a str, salt: [u8; 4] }, +} + +impl Password<'_> { + #[inline] + fn len(&self) -> usize { + match self { + Password::Cleartext(s) => s.len() + 5, + Password::Md5 { .. } => 35 + 5, + } + } +} + +impl Serialize<'_, ()> for Password<'_> { + fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { + buf.reserve(1 + 4 + self.len()); + buf.push(b'p'); + + buf.write_length_prefixed(|buf| { + match self { + Password::Cleartext(password) => { + buf.write_str_nul(password); + } + + Password::Md5 { username, password, salt } => { + // The actual `PasswordMessage` can be comwriteed in SQL as + // `concat('md5', md5(concat(md5(concat(password, username)), random-salt)))`. + + // Keep in mind the md5() function returns its result as a hex string. + + let digest = md5::compute(password); + let digest = md5::compute(username); + + let mut outwrite = String::with_capacity(35); + + let _ = write!(outwrite, "{:x}", digest); + + let digest = md5::compute(&outwrite); + let digest = md5::compute(salt); + + outwrite.clear(); + + let _ = write!(outwrite, "md5{:x}", digest); + + buf.write_str_nul(&outwrite); + } + } + }); + + Ok(()) + } +} diff --git a/sqlx-postgres/src/protocol/ready_for_query.rs b/sqlx-postgres/src/protocol/ready_for_query.rs new file mode 100644 index 00000000..81a4159c --- /dev/null +++ b/sqlx-postgres/src/protocol/ready_for_query.rs @@ -0,0 +1,51 @@ +use std::convert::TryFrom; + +use bytes::Bytes; +use sqlx_core::io::Deserialize; +use sqlx_core::Error; +use sqlx_core::Result; + +#[derive(Debug)] +#[repr(u8)] +pub(crate) enum TransactionStatus { + /// Not in a transaction block. + Idle = b'I', + + /// In a transaction block. + Transaction = b'T', + + /// In a _failed_ transaction block. Queries will be rejected until block is ended. + Error = b'E', +} + +impl TryFrom for TransactionStatus { + type Error = Error; + + fn try_from(value: u8) -> Result { + match value { + b'I' => Ok(TransactionStatus::Idle), + b'T' => Ok(TransactionStatus::Transaction), + b'E' => Ok(TransactionStatus::Error), + + status => { + return Err(Error::configuration_msg(format!( + "unknown transaction status: {:?}", + status as char, + ))); + } + } + } +} + +#[derive(Debug)] +pub(crate) struct ReadyForQuery { + pub transaction_status: TransactionStatus, +} + +impl Deserialize<'_, ()> for ReadyForQuery { + fn deserialize_with(mut buf: Bytes, _: ()) -> Result { + let transaction_status = TransactionStatus::try_from(buf[0])?; + + Ok(Self { transaction_status }) + } +}