diff --git a/sqlx-postgres/Cargo.toml b/sqlx-postgres/Cargo.toml index d1fbd0d7..dee0d72b 100644 --- a/sqlx-postgres/Cargo.toml +++ b/sqlx-postgres/Cargo.toml @@ -38,6 +38,7 @@ futures-io = { version = "0.3", optional = true } bytes = "1.0" memchr = "2.3" bitflags = "1.2" +base64 = "0.13.0" [dev-dependencies] sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core", features = ["_mock"] } diff --git a/sqlx-postgres/src/connection.rs b/sqlx-postgres/src/connection.rs index 263a8a71..35fab1a9 100644 --- a/sqlx-postgres/src/connection.rs +++ b/sqlx-postgres/src/connection.rs @@ -2,10 +2,13 @@ use std::fmt::{self, Debug, Formatter}; #[cfg(feature = "async")] use futures_util::future::{BoxFuture, FutureExt, TryFutureExt}; +use sqlx_core::net::Stream as NetStream; use sqlx_core::{Close, Connect, Connection, Runtime}; use crate::stream::PgStream; -use crate::Postgres; +use crate::{PgConnectOptions, Postgres}; + +mod connect; /// A single connection (also known as a session) to a /// PostgreSQL database server. @@ -64,10 +67,10 @@ impl Connection for PgConnection { } impl Connect for PgConnection { - type Options = PostgresConnectOptions; + type Options = PgConnectOptions; #[cfg(feature = "async")] - fn connect_with(options: &PostgresConnectOptions) -> BoxFuture<'_, sqlx_core::Result> + fn connect_with(options: &PgConnectOptions) -> BoxFuture<'_, sqlx_core::Result> where Self: Sized, Rt: sqlx_core::Async, @@ -82,6 +85,52 @@ impl Close for PgConnection { where Rt: sqlx_core::Async, { - todo!() + Box::pin(async move { + self.stream.close_async().await?; + + Ok(()) + }) + } +} + +#[cfg(feature = "blocking")] +mod blocking { + use sqlx_core::blocking::{Close, Connect, Connection, Runtime}; + + use super::{PgConnectOptions, PgConnection, Postgres}; + + impl Connection for PgConnection { + #[inline] + fn ping(&mut self) -> sqlx_core::Result<()> { + todo!() + } + + fn describe<'x, 'e, 'q>( + &'e mut self, + query: &'q str, + ) -> sqlx_core::Result> + where + 'e: 'x, + 'q: 'x, + { + todo!() + } + } + + impl Connect for PgConnection { + #[inline] + fn connect_with(options: &PgConnectOptions) -> sqlx_core::Result + where + Self: Sized, + { + Self::connect_blocking(options) + } + } + + impl Close for PgConnection { + #[inline] + fn close(mut self) -> sqlx_core::Result<()> { + self.stream.close_blocking() + } } } diff --git a/sqlx-postgres/src/connection/connect.rs b/sqlx-postgres/src/connection/connect.rs new file mode 100644 index 00000000..3b1b3810 --- /dev/null +++ b/sqlx-postgres/src/connection/connect.rs @@ -0,0 +1,129 @@ +//! Implements start-up flow. +//! +//! To begin a session, a frontend opens a connection to the server +//! and sends a startup message. +//! +//! The server then sends an appropriate authentication request message, to +//! which the frontend must reply with an appropriate authentication +//! response message. +//! +//! The authentication cycle ends with the server either rejecting +//! the connection attempt (ErrorResponse), or sending AuthenticationOk. +//! +//! + +use sqlx_core::net::Stream as NetStream; +use sqlx_core::{Error, Result, Runtime}; + +use crate::protocol::backend::{Authentication, BackendMessage, BackendMessageType}; +use crate::protocol::frontend::Startup; +use crate::{PgClientError, PgConnectOptions, PgConnection}; + +impl PgConnection { + fn write_startup_message(&mut self, options: &PgConnectOptions) -> Result<()> { + let params = vec![ + ("user", options.get_username()), + ("database", options.get_database()), + ("application_name", options.get_application_name()), + // sets the text display format for date and time values + // as well as the rules for interpreting ambiguous date input values + ("DateStyle", Some("ISO, MDY")), + // sets the client-side encoding (charset) + // NOTE: this must not be changed, too much in the driver depends on this being set to UTF-8 + ("client_encoding", Some("UTF8")), + // sets the timezone for displaying and interpreting time stamps + // NOTE: this is only used to assume timestamptz values are in UTC + ("TimeZone", Some("UTC")), + ]; + + self.stream.write_message(&Startup(¶ms)) + } + + fn handle_startup_response( + &mut self, + options: &PgConnectOptions, + message: BackendMessage, + ) -> Result { + match message.ty { + BackendMessageType::Authentication => match message.deserialize()? { + Authentication::Ok => { + return Ok(true); + } + + Authentication::Md5Password(_) => { + todo!("md5") + } + + Authentication::CleartextPassword => { + todo!("cleartext") + } + + Authentication::Sasl(_) => todo!("sasl"), + Authentication::SaslContinue(_) => todo!("sasl continue"), + Authentication::SaslFinal(_) => todo!("sasl final"), + }, + + ty => { + return Err(Error::client(PgClientError::UnexpectedMessageType { + ty: ty as u8, + context: "starting up", + })); + } + } + } +} + +macro_rules! impl_connect { + (@blocking @new $options:ident) => { + NetStream::connect($options.address.as_ref())? + }; + + (@new $options:ident) => { + NetStream::connect_async($options.address.as_ref()).await? + }; + + ($(@$blocking:ident)? $options:ident) => {{ + // open a network stream to the database server + let stream = impl_connect!($(@$blocking)? @new $options); + + // construct a around the network stream + // wraps the stream in a to buffer read and write + let mut self_ = Self::new(stream); + + // to begin a session, a frontend should send a startup message + // this is built up of various startup parameters that control the connection + self_.write_startup_message($options)?; + + // the server then uses this information and the contents of + // its configuration files (such as pg_hba.conf) to determine whether the connection is + // provisionally acceptable, and what additional + // authentication is required (if any). + loop { + let message = read_message!($(@$blocking)? self_.stream); + if self_.handle_startup_response($options, message)? { + // complete, successful authentication + break; + } + } + + Ok(self_) + }}; +} + +impl PgConnection { + #[cfg(feature = "async")] + pub(crate) async fn connect_async(options: &PgConnectOptions) -> Result + where + Rt: sqlx_core::Async, + { + impl_connect!(options) + } + + #[cfg(feature = "blocking")] + pub(crate) fn connect_blocking(options: &PgConnectOptions) -> Result + where + Rt: sqlx_core::blocking::Runtime, + { + impl_connect!(@blocking options) + } +} diff --git a/sqlx-postgres/src/error.rs b/sqlx-postgres/src/error.rs new file mode 100644 index 00000000..3fee4428 --- /dev/null +++ b/sqlx-postgres/src/error.rs @@ -0,0 +1,4 @@ +mod client; +mod database; + +pub use client::PgClientError; diff --git a/sqlx-postgres/src/error/client.rs b/sqlx-postgres/src/error/client.rs new file mode 100644 index 00000000..451d5bff --- /dev/null +++ b/sqlx-postgres/src/error/client.rs @@ -0,0 +1,40 @@ +use crate::protocol::backend::BackendMessageType; +use sqlx_core::ClientError; +use std::error::Error as StdError; +use std::fmt::{self, Display, Formatter}; +use std::str::Utf8Error; + +#[derive(Debug)] +#[non_exhaustive] +pub enum PgClientError { + // attempting to interpret data from postgres as UTF-8, when it should + // be UTF-8, but for some reason (data corruption?) it is not + NotUtf8(Utf8Error), + UnknownAuthenticationMethod(u32), + UnknownMessageType(u8), + UnexpectedMessageType { ty: u8, context: &'static str }, +} + +impl Display for PgClientError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::NotUtf8(source) => write!(f, "{}", source), + + Self::UnknownAuthenticationMethod(method) => { + write!(f, "unknown authentication method: {}", method) + } + + Self::UnknownMessageType(ty) => { + write!(f, "unknown protocol message type: '{}' ({})", *ty as char, *ty) + } + + Self::UnexpectedMessageType { ty, context } => { + write!(f, "unexpected message {:?} '{}' while {}", ty, (*ty as u8 as char), context) + } + } + } +} + +impl StdError for PgClientError {} + +impl ClientError for PgClientError {} diff --git a/sqlx-postgres/src/error/database.rs b/sqlx-postgres/src/error/database.rs new file mode 100644 index 00000000..e69de29b diff --git a/sqlx-postgres/src/io.rs b/sqlx-postgres/src/io.rs new file mode 100644 index 00000000..ad5ffda3 --- /dev/null +++ b/sqlx-postgres/src/io.rs @@ -0,0 +1,3 @@ +mod write; + +pub(crate) use write::PgWriteExt; diff --git a/sqlx-postgres/src/io/write.rs b/sqlx-postgres/src/io/write.rs new file mode 100644 index 00000000..31922611 --- /dev/null +++ b/sqlx-postgres/src/io/write.rs @@ -0,0 +1,31 @@ +use sqlx_core::io::WriteExt; +use sqlx_core::Result; + +pub trait PgWriteExt: WriteExt { + fn write_len_prefixed(&mut self, f: F) -> Result<()> + where + F: FnOnce(&mut Vec) -> Result<()>; +} + +impl PgWriteExt for Vec { + /// Writes a length-prefixed message, this is used when encoding nearly + /// all messages as postgres wants us to send the length of the + /// often-variable-sized messages up front. + fn write_len_prefixed(&mut self, f: F) -> Result<()> + where + F: FnOnce(&mut Vec) -> Result<()>, + { + // reserve space to write the prefixed length + let offset = self.len(); + self.extend_from_slice(&[0; 4]); + + // write the main body of the message + f(self)?; + + // now calculate the size of what we wrote and set the length value + let size = (self.len() - offset) as i32; + self[offset..(offset + 4)].copy_from_slice(&size.to_be_bytes()); + + Ok(()) + } +} diff --git a/sqlx-postgres/src/lib.rs b/sqlx-postgres/src/lib.rs index b2448209..1846a0e4 100644 --- a/sqlx-postgres/src/lib.rs +++ b/sqlx-postgres/src/lib.rs @@ -28,8 +28,8 @@ mod stream; mod column; mod connection; mod database; -// mod error; -// mod io; +mod error; +mod io; mod options; mod output; mod protocol; @@ -46,9 +46,10 @@ pub mod types; // mod mock; pub use column::PgColumn; -// pub use connection::PgConnection; +pub use connection::PgConnection; pub use database::Postgres; // pub use error::PgDatabaseError; +pub use error::PgClientError; pub use options::PgConnectOptions; pub use output::PgOutput; pub use query_result::PgQueryResult; diff --git a/sqlx-postgres/src/protocol.rs b/sqlx-postgres/src/protocol.rs index 3ef8d5de..9f6aebb7 100644 --- a/sqlx-postgres/src/protocol.rs +++ b/sqlx-postgres/src/protocol.rs @@ -1,3 +1,2 @@ -mod message; - -pub(crate) use message::{BackendMessage, BackendMessageType}; +pub(crate) mod backend; +pub(crate) mod frontend; diff --git a/sqlx-postgres/src/protocol/backend.rs b/sqlx-postgres/src/protocol/backend.rs new file mode 100644 index 00000000..695c5114 --- /dev/null +++ b/sqlx-postgres/src/protocol/backend.rs @@ -0,0 +1,7 @@ +mod auth; +mod message; +mod sasl; + +pub(crate) use auth::{Authentication, AuthenticationMd5Password}; +pub(crate) use message::{BackendMessage, BackendMessageType}; +pub(crate) use sasl::{AuthenticationSasl, AuthenticationSaslContinue, AuthenticationSaslFinal}; diff --git a/sqlx-postgres/src/protocol/backend/auth.rs b/sqlx-postgres/src/protocol/backend/auth.rs new file mode 100644 index 00000000..95f2c095 --- /dev/null +++ b/sqlx-postgres/src/protocol/backend/auth.rs @@ -0,0 +1,61 @@ +use bytes::{Buf, Bytes}; +use sqlx_core::io::Deserialize; +use sqlx_core::{Error, Result}; + +use crate::protocol::backend::{ + AuthenticationSasl, AuthenticationSaslContinue, AuthenticationSaslFinal, +}; +use crate::PgClientError; + +#[derive(Debug)] +pub(crate) enum Authentication { + /// The authentication exchange is successfully completed. + Ok, + + /// The frontend must now send a PasswordMessage containing the + /// password in clear-text form. + CleartextPassword, + + /// The frontend must now send a PasswordMessage containing the + /// password (with user name) encrypted via MD5. + Md5Password(AuthenticationMd5Password), + + /// The frontend must now initiate a SASL negotiation, + /// using one of the SASL mechanisms listed in the message. + Sasl(AuthenticationSasl), + + /// This message contains challenge data from the previous step of + /// SASL negotiation. + SaslContinue(AuthenticationSaslContinue), + + /// SASL authentication has completed with additional mechanism-specific + /// data for the client. + SaslFinal(AuthenticationSaslFinal), +} + +impl Deserialize<'_> for Authentication { + fn deserialize_with(mut buf: Bytes, _: ()) -> Result { + match buf.get_u32() { + 0 => Ok(Self::Ok), + 3 => Ok(Self::CleartextPassword), + + 5 => { + let mut salt = [0_u8; 4]; + buf.copy_to_slice(&mut salt); + + Ok(Self::Md5Password(AuthenticationMd5Password { salt })) + } + + 10 => AuthenticationSasl::deserialize(buf).map(Self::Sasl), + 11 => AuthenticationSaslContinue::deserialize(buf).map(Self::SaslContinue), + 12 => AuthenticationSaslFinal::deserialize(buf).map(Self::SaslFinal), + + ty => Err(Error::client(PgClientError::UnknownAuthenticationMethod(ty))), + } + } +} + +#[derive(Debug)] +pub(crate) struct AuthenticationMd5Password { + pub(crate) salt: [u8; 4], +} diff --git a/sqlx-postgres/src/protocol/message.rs b/sqlx-postgres/src/protocol/backend/message.rs similarity index 94% rename from sqlx-postgres/src/protocol/message.rs rename to sqlx-postgres/src/protocol/backend/message.rs index 458eb7e7..1e995b98 100644 --- a/sqlx-postgres/src/protocol/message.rs +++ b/sqlx-postgres/src/protocol/backend/message.rs @@ -1,6 +1,7 @@ use std::convert::TryFrom; use std::fmt::Debug; +use crate::PgClientError; use bytes::Bytes; use sqlx_core::io::Deserialize; use sqlx_core::{Error, Result}; @@ -66,7 +67,7 @@ impl TryFrom for BackendMessageType { b'c' => Self::CopyDone, _ => { - todo!("protocol unexpected data error") + return Err(Error::client(PgClientError::UnknownMessageType(ty))); } }) } @@ -74,7 +75,7 @@ impl TryFrom for BackendMessageType { #[derive(Debug)] pub(crate) struct BackendMessage { - pub(crate) r#type: BackendMessageType, + pub(crate) ty: BackendMessageType, pub(crate) contents: Bytes, } diff --git a/sqlx-postgres/src/protocol/backend/sasl.rs b/sqlx-postgres/src/protocol/backend/sasl.rs new file mode 100644 index 00000000..0d79495f --- /dev/null +++ b/sqlx-postgres/src/protocol/backend/sasl.rs @@ -0,0 +1,88 @@ +use std::convert::TryFrom; + +use bytes::Bytes; +use bytestring::ByteString; +use sqlx_core::io::Deserialize; +use sqlx_core::Result; + +#[derive(Debug)] +pub(crate) struct AuthenticationSasl(Bytes); + +impl Deserialize<'_> for AuthenticationSasl { + fn deserialize_with(buf: Bytes, _: ()) -> Result { + Ok(Self(buf)) + } +} + +#[derive(Debug)] +pub(crate) struct AuthenticationSaslContinue { + pub(crate) salt: Box<[u8]>, + pub(crate) iterations: u32, + pub(crate) nonce: ByteString, + pub(crate) message: ByteString, +} + +impl Deserialize<'_> for AuthenticationSaslContinue { + fn deserialize_with(buf: Bytes, _: ()) -> Result { + let mut iterations: u32 = 4096; + let mut salt = Vec::new(); + let mut nonce = Bytes::new(); + + // [Example] + // r=/z+giZiTxAH7r8sNAeHr7cvpqV3uo7G/bJBIJO3pjVM7t3ng,s=4UV68bIkC8f9/X8xH7aPhg==,i=4096 + + for item in buf.split(|b| *b == b',') { + let key = item[0]; + let value = &item[2..]; + + match key { + b'r' => { + nonce = buf.slice_ref(value); + } + + b'i' => { + iterations = atoi::atoi(value).unwrap_or(4096); + } + + b's' => { + // FIXME: raise proper protocol errors + salt = base64::decode(value).unwrap(); + } + + _ => {} + } + } + + Ok(Self { + iterations, + salt: salt.into_boxed_slice(), + + // FIXME: raise proper protocol errors + nonce: ByteString::try_from(nonce).unwrap(), + message: ByteString::try_from(buf).unwrap(), + }) + } +} + +#[derive(Debug)] +pub(crate) struct AuthenticationSaslFinal { + pub(crate) verifier: Box<[u8]>, +} + +impl Deserialize<'_> for AuthenticationSaslFinal { + fn deserialize_with(buf: Bytes, _: ()) -> Result { + let mut verifier = Vec::new(); + + for item in buf.split(|b| *b == b',') { + let key = item[0]; + let value = &item[2..]; + + if let b'v' = key { + // FIXME: raise proper protocol errors + verifier = base64::decode(value).unwrap(); + } + } + + Ok(Self { verifier: verifier.into_boxed_slice() }) + } +} diff --git a/sqlx-postgres/src/protocol/frontend.rs b/sqlx-postgres/src/protocol/frontend.rs new file mode 100644 index 00000000..96712d95 --- /dev/null +++ b/sqlx-postgres/src/protocol/frontend.rs @@ -0,0 +1,5 @@ +mod startup; +mod terminate; + +pub(crate) use startup::Startup; +pub(crate) use terminate::Terminate; diff --git a/sqlx-postgres/src/protocol/frontend/startup.rs b/sqlx-postgres/src/protocol/frontend/startup.rs new file mode 100644 index 00000000..6b8e5939 --- /dev/null +++ b/sqlx-postgres/src/protocol/frontend/startup.rs @@ -0,0 +1,55 @@ +use sqlx_core::io::Serialize; +use sqlx_core::Result; + +use crate::io::PgWriteExt; + +#[derive(Debug)] +pub(crate) struct Startup<'a>(pub(crate) &'a [(&'a str, Option<&'a str>)]); + +impl Serialize<'_> for Startup<'_> { + fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { + buf.write_len_prefixed(|buf| { + // The protocol version number. The most significant 16 bits are the + // major version number (3 for the protocol described here). The least + // significant 16 bits are the minor version number (0 + // for the protocol described here) + buf.extend(&196_608_i32.to_be_bytes()); + + // For each startup parameter, write the name and value + // as NUL-terminated strings + for (name, value) in self.0 { + if let Some(value) = value { + write_startup_param(buf, name, value); + } + } + + // Followed by a trailing NUL + buf.push(0); + + Ok(()) + }) + } +} + +fn write_startup_param(buf: &mut Vec, name: &str, value: &str) { + buf.reserve(name.len() + value.len() + 2); + buf.extend(name.as_bytes()); + buf.push(0); + buf.extend(value.as_bytes()); + buf.push(0); +} + +#[cfg(test)] +mod tests { + use super::{Serialize, Startup}; + + #[test] + fn should_encode_startup() { + let mut buf = Vec::new(); + let m = Startup(&[("user", Some("postgres")), ("database", Some("postgres"))]); + + m.serialize(&mut buf).unwrap(); + + assert_eq!(buf, b"\0\0\0)\0\x03\0\0user\0postgres\0database\0postgres\0\0"); + } +} diff --git a/sqlx-postgres/src/protocol/frontend/terminate.rs b/sqlx-postgres/src/protocol/frontend/terminate.rs new file mode 100644 index 00000000..8c7fccdb --- /dev/null +++ b/sqlx-postgres/src/protocol/frontend/terminate.rs @@ -0,0 +1,32 @@ +use sqlx_core::io::Serialize; +use sqlx_core::Result; + +/// On receipt of this message, the backend closes the connection +/// and terminates. +#[derive(Debug)] +pub(crate) struct Terminate; + +impl Serialize<'_> for Terminate { + fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { + buf.push(b'X'); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use sqlx_core::io::Serialize; + + use super::Terminate; + + #[test] + fn should_serialize() -> anyhow::Result<()> { + let mut buf = Vec::new(); + Terminate.serialize(&mut buf)?; + + assert_eq!(&buf, &[b'X']); + + Ok(()) + } +} diff --git a/sqlx-postgres/src/stream.rs b/sqlx-postgres/src/stream.rs index 4bf3e5f3..ba9ee0b8 100644 --- a/sqlx-postgres/src/stream.rs +++ b/sqlx-postgres/src/stream.rs @@ -3,11 +3,12 @@ use std::fmt::Debug; use std::ops::{Deref, DerefMut}; use bytes::Buf; -use sqlx_core::io::{BufStream, Serialize}; +use sqlx_core::io::{BufStream, Serialize, Stream}; use sqlx_core::net::Stream as NetStream; use sqlx_core::{Result, Runtime}; -use crate::protocol::{BackendMessage, BackendMessageType}; +use crate::protocol::backend::{BackendMessage, BackendMessageType}; +use crate::protocol::frontend::Terminate; /// Reads and writes messages to and from the PostgreSQL database server. /// @@ -38,6 +39,10 @@ impl PgStream { where T: Serialize<'ser> + Debug, { + log::trace!("write > {:?}", message); + + message.serialize(self.stream.buffer())?; + Ok(()) } @@ -86,7 +91,7 @@ impl PgStream { Ok(None) } - _ => Ok(Some(BackendMessage { contents, r#type: ty })), + _ => Ok(Some(BackendMessage { contents, ty })), } } } @@ -159,3 +164,29 @@ macro_rules! read_message { $stream.read_message_async().await? }; } + +impl PgStream { + #[cfg(feature = "async")] + pub(crate) async fn close_async(&mut self) -> Result<()> + where + Rt: sqlx_core::Async, + { + self.write_message(&Terminate)?; + self.flush_async().await?; + self.shutdown_async().await?; + + Ok(()) + } + + #[cfg(feature = "blocking")] + pub(crate) fn close_blocking(&mut self) -> Result<()> + where + Rt: sqlx_core::blocking::Runtime, + { + self.write_message(&Terminate)?; + self.flush()?; + self.shutdown()?; + + Ok(()) + } +}