feat(postgres): stub start-up flow

This commit is contained in:
Ryan Leckey 2021-03-20 00:16:20 -07:00
parent 424d4b7aa1
commit 1eb1cd3ea9
No known key found for this signature in database
GPG Key ID: F8AA68C235AB08C9
18 changed files with 552 additions and 15 deletions

View File

@ -38,6 +38,7 @@ futures-io = { version = "0.3", optional = true }
bytes = "1.0"
memchr = "2.3"
bitflags = "1.2"
base64 = "0.13.0"
[dev-dependencies]
sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core", features = ["_mock"] }

View File

@ -2,10 +2,13 @@ use std::fmt::{self, Debug, Formatter};
#[cfg(feature = "async")]
use futures_util::future::{BoxFuture, FutureExt, TryFutureExt};
use sqlx_core::net::Stream as NetStream;
use sqlx_core::{Close, Connect, Connection, Runtime};
use crate::stream::PgStream;
use crate::Postgres;
use crate::{PgConnectOptions, Postgres};
mod connect;
/// A single connection (also known as a session) to a
/// PostgreSQL database server.
@ -64,10 +67,10 @@ impl<Rt: Runtime> Connection<Rt> for PgConnection<Rt> {
}
impl<Rt: Runtime> Connect<Rt> for PgConnection<Rt> {
type Options = PostgresConnectOptions;
type Options = PgConnectOptions;
#[cfg(feature = "async")]
fn connect_with(options: &PostgresConnectOptions) -> BoxFuture<'_, sqlx_core::Result<Self>>
fn connect_with(options: &PgConnectOptions) -> BoxFuture<'_, sqlx_core::Result<Self>>
where
Self: Sized,
Rt: sqlx_core::Async,
@ -82,6 +85,52 @@ impl<Rt: Runtime> Close<Rt> for PgConnection<Rt> {
where
Rt: sqlx_core::Async,
{
todo!()
Box::pin(async move {
self.stream.close_async().await?;
Ok(())
})
}
}
#[cfg(feature = "blocking")]
mod blocking {
use sqlx_core::blocking::{Close, Connect, Connection, Runtime};
use super::{PgConnectOptions, PgConnection, Postgres};
impl<Rt: Runtime> Connection<Rt> for PgConnection<Rt> {
#[inline]
fn ping(&mut self) -> sqlx_core::Result<()> {
todo!()
}
fn describe<'x, 'e, 'q>(
&'e mut self,
query: &'q str,
) -> sqlx_core::Result<sqlx_core::Describe<Postgres>>
where
'e: 'x,
'q: 'x,
{
todo!()
}
}
impl<Rt: Runtime> Connect<Rt> for PgConnection<Rt> {
#[inline]
fn connect_with(options: &PgConnectOptions) -> sqlx_core::Result<Self>
where
Self: Sized,
{
Self::connect_blocking(options)
}
}
impl<Rt: Runtime> Close<Rt> for PgConnection<Rt> {
#[inline]
fn close(mut self) -> sqlx_core::Result<()> {
self.stream.close_blocking()
}
}
}

View File

@ -0,0 +1,129 @@
//! Implements start-up flow.
//!
//! To begin a session, a frontend opens a connection to the server
//! and sends a startup message.
//!
//! The server then sends an appropriate authentication request message, to
//! which the frontend must reply with an appropriate authentication
//! response message.
//!
//! The authentication cycle ends with the server either rejecting
//! the connection attempt (ErrorResponse), or sending AuthenticationOk.
//!
//! <https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.3>
use sqlx_core::net::Stream as NetStream;
use sqlx_core::{Error, Result, Runtime};
use crate::protocol::backend::{Authentication, BackendMessage, BackendMessageType};
use crate::protocol::frontend::Startup;
use crate::{PgClientError, PgConnectOptions, PgConnection};
impl<Rt: Runtime> PgConnection<Rt> {
fn write_startup_message(&mut self, options: &PgConnectOptions) -> Result<()> {
let params = vec![
("user", options.get_username()),
("database", options.get_database()),
("application_name", options.get_application_name()),
// sets the text display format for date and time values
// as well as the rules for interpreting ambiguous date input values
("DateStyle", Some("ISO, MDY")),
// sets the client-side encoding (charset)
// NOTE: this must not be changed, too much in the driver depends on this being set to UTF-8
("client_encoding", Some("UTF8")),
// sets the timezone for displaying and interpreting time stamps
// NOTE: this is only used to assume timestamptz values are in UTC
("TimeZone", Some("UTC")),
];
self.stream.write_message(&Startup(&params))
}
fn handle_startup_response(
&mut self,
options: &PgConnectOptions,
message: BackendMessage,
) -> Result<bool> {
match message.ty {
BackendMessageType::Authentication => match message.deserialize()? {
Authentication::Ok => {
return Ok(true);
}
Authentication::Md5Password(_) => {
todo!("md5")
}
Authentication::CleartextPassword => {
todo!("cleartext")
}
Authentication::Sasl(_) => todo!("sasl"),
Authentication::SaslContinue(_) => todo!("sasl continue"),
Authentication::SaslFinal(_) => todo!("sasl final"),
},
ty => {
return Err(Error::client(PgClientError::UnexpectedMessageType {
ty: ty as u8,
context: "starting up",
}));
}
}
}
}
macro_rules! impl_connect {
(@blocking @new $options:ident) => {
NetStream::connect($options.address.as_ref())?
};
(@new $options:ident) => {
NetStream::connect_async($options.address.as_ref()).await?
};
($(@$blocking:ident)? $options:ident) => {{
// open a network stream to the database server
let stream = impl_connect!($(@$blocking)? @new $options);
// construct a <PgConnection> around the network stream
// wraps the stream in a <BufStream> to buffer read and write
let mut self_ = Self::new(stream);
// to begin a session, a frontend should send a startup message
// this is built up of various startup parameters that control the connection
self_.write_startup_message($options)?;
// the server then uses this information and the contents of
// its configuration files (such as pg_hba.conf) to determine whether the connection is
// provisionally acceptable, and what additional
// authentication is required (if any).
loop {
let message = read_message!($(@$blocking)? self_.stream);
if self_.handle_startup_response($options, message)? {
// complete, successful authentication
break;
}
}
Ok(self_)
}};
}
impl<Rt: Runtime> PgConnection<Rt> {
#[cfg(feature = "async")]
pub(crate) async fn connect_async(options: &PgConnectOptions) -> Result<Self>
where
Rt: sqlx_core::Async,
{
impl_connect!(options)
}
#[cfg(feature = "blocking")]
pub(crate) fn connect_blocking(options: &PgConnectOptions) -> Result<Self>
where
Rt: sqlx_core::blocking::Runtime,
{
impl_connect!(@blocking options)
}
}

View File

@ -0,0 +1,4 @@
mod client;
mod database;
pub use client::PgClientError;

View File

@ -0,0 +1,40 @@
use crate::protocol::backend::BackendMessageType;
use sqlx_core::ClientError;
use std::error::Error as StdError;
use std::fmt::{self, Display, Formatter};
use std::str::Utf8Error;
#[derive(Debug)]
#[non_exhaustive]
pub enum PgClientError {
// attempting to interpret data from postgres as UTF-8, when it should
// be UTF-8, but for some reason (data corruption?) it is not
NotUtf8(Utf8Error),
UnknownAuthenticationMethod(u32),
UnknownMessageType(u8),
UnexpectedMessageType { ty: u8, context: &'static str },
}
impl Display for PgClientError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
Self::NotUtf8(source) => write!(f, "{}", source),
Self::UnknownAuthenticationMethod(method) => {
write!(f, "unknown authentication method: {}", method)
}
Self::UnknownMessageType(ty) => {
write!(f, "unknown protocol message type: '{}' ({})", *ty as char, *ty)
}
Self::UnexpectedMessageType { ty, context } => {
write!(f, "unexpected message {:?} '{}' while {}", ty, (*ty as u8 as char), context)
}
}
}
}
impl StdError for PgClientError {}
impl ClientError for PgClientError {}

View File

3
sqlx-postgres/src/io.rs Normal file
View File

@ -0,0 +1,3 @@
mod write;
pub(crate) use write::PgWriteExt;

View File

@ -0,0 +1,31 @@
use sqlx_core::io::WriteExt;
use sqlx_core::Result;
pub trait PgWriteExt: WriteExt {
fn write_len_prefixed<F>(&mut self, f: F) -> Result<()>
where
F: FnOnce(&mut Vec<u8>) -> Result<()>;
}
impl PgWriteExt for Vec<u8> {
/// Writes a length-prefixed message, this is used when encoding nearly
/// all messages as postgres wants us to send the length of the
/// often-variable-sized messages up front.
fn write_len_prefixed<F>(&mut self, f: F) -> Result<()>
where
F: FnOnce(&mut Vec<u8>) -> Result<()>,
{
// reserve space to write the prefixed length
let offset = self.len();
self.extend_from_slice(&[0; 4]);
// write the main body of the message
f(self)?;
// now calculate the size of what we wrote and set the length value
let size = (self.len() - offset) as i32;
self[offset..(offset + 4)].copy_from_slice(&size.to_be_bytes());
Ok(())
}
}

View File

@ -28,8 +28,8 @@ mod stream;
mod column;
mod connection;
mod database;
// mod error;
// mod io;
mod error;
mod io;
mod options;
mod output;
mod protocol;
@ -46,9 +46,10 @@ pub mod types;
// mod mock;
pub use column::PgColumn;
// pub use connection::PgConnection;
pub use connection::PgConnection;
pub use database::Postgres;
// pub use error::PgDatabaseError;
pub use error::PgClientError;
pub use options::PgConnectOptions;
pub use output::PgOutput;
pub use query_result::PgQueryResult;

View File

@ -1,3 +1,2 @@
mod message;
pub(crate) use message::{BackendMessage, BackendMessageType};
pub(crate) mod backend;
pub(crate) mod frontend;

View File

@ -0,0 +1,7 @@
mod auth;
mod message;
mod sasl;
pub(crate) use auth::{Authentication, AuthenticationMd5Password};
pub(crate) use message::{BackendMessage, BackendMessageType};
pub(crate) use sasl::{AuthenticationSasl, AuthenticationSaslContinue, AuthenticationSaslFinal};

View File

@ -0,0 +1,61 @@
use bytes::{Buf, Bytes};
use sqlx_core::io::Deserialize;
use sqlx_core::{Error, Result};
use crate::protocol::backend::{
AuthenticationSasl, AuthenticationSaslContinue, AuthenticationSaslFinal,
};
use crate::PgClientError;
#[derive(Debug)]
pub(crate) enum Authentication {
/// The authentication exchange is successfully completed.
Ok,
/// The frontend must now send a PasswordMessage containing the
/// password in clear-text form.
CleartextPassword,
/// The frontend must now send a PasswordMessage containing the
/// password (with user name) encrypted via MD5.
Md5Password(AuthenticationMd5Password),
/// The frontend must now initiate a SASL negotiation,
/// using one of the SASL mechanisms listed in the message.
Sasl(AuthenticationSasl),
/// This message contains challenge data from the previous step of
/// SASL negotiation.
SaslContinue(AuthenticationSaslContinue),
/// SASL authentication has completed with additional mechanism-specific
/// data for the client.
SaslFinal(AuthenticationSaslFinal),
}
impl Deserialize<'_> for Authentication {
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
match buf.get_u32() {
0 => Ok(Self::Ok),
3 => Ok(Self::CleartextPassword),
5 => {
let mut salt = [0_u8; 4];
buf.copy_to_slice(&mut salt);
Ok(Self::Md5Password(AuthenticationMd5Password { salt }))
}
10 => AuthenticationSasl::deserialize(buf).map(Self::Sasl),
11 => AuthenticationSaslContinue::deserialize(buf).map(Self::SaslContinue),
12 => AuthenticationSaslFinal::deserialize(buf).map(Self::SaslFinal),
ty => Err(Error::client(PgClientError::UnknownAuthenticationMethod(ty))),
}
}
}
#[derive(Debug)]
pub(crate) struct AuthenticationMd5Password {
pub(crate) salt: [u8; 4],
}

View File

@ -1,6 +1,7 @@
use std::convert::TryFrom;
use std::fmt::Debug;
use crate::PgClientError;
use bytes::Bytes;
use sqlx_core::io::Deserialize;
use sqlx_core::{Error, Result};
@ -66,7 +67,7 @@ impl TryFrom<u8> for BackendMessageType {
b'c' => Self::CopyDone,
_ => {
todo!("protocol unexpected data error")
return Err(Error::client(PgClientError::UnknownMessageType(ty)));
}
})
}
@ -74,7 +75,7 @@ impl TryFrom<u8> for BackendMessageType {
#[derive(Debug)]
pub(crate) struct BackendMessage {
pub(crate) r#type: BackendMessageType,
pub(crate) ty: BackendMessageType,
pub(crate) contents: Bytes,
}

View File

@ -0,0 +1,88 @@
use std::convert::TryFrom;
use bytes::Bytes;
use bytestring::ByteString;
use sqlx_core::io::Deserialize;
use sqlx_core::Result;
#[derive(Debug)]
pub(crate) struct AuthenticationSasl(Bytes);
impl Deserialize<'_> for AuthenticationSasl {
fn deserialize_with(buf: Bytes, _: ()) -> Result<Self> {
Ok(Self(buf))
}
}
#[derive(Debug)]
pub(crate) struct AuthenticationSaslContinue {
pub(crate) salt: Box<[u8]>,
pub(crate) iterations: u32,
pub(crate) nonce: ByteString,
pub(crate) message: ByteString,
}
impl Deserialize<'_> for AuthenticationSaslContinue {
fn deserialize_with(buf: Bytes, _: ()) -> Result<Self> {
let mut iterations: u32 = 4096;
let mut salt = Vec::new();
let mut nonce = Bytes::new();
// [Example]
// r=/z+giZiTxAH7r8sNAeHr7cvpqV3uo7G/bJBIJO3pjVM7t3ng,s=4UV68bIkC8f9/X8xH7aPhg==,i=4096
for item in buf.split(|b| *b == b',') {
let key = item[0];
let value = &item[2..];
match key {
b'r' => {
nonce = buf.slice_ref(value);
}
b'i' => {
iterations = atoi::atoi(value).unwrap_or(4096);
}
b's' => {
// FIXME: raise proper protocol errors
salt = base64::decode(value).unwrap();
}
_ => {}
}
}
Ok(Self {
iterations,
salt: salt.into_boxed_slice(),
// FIXME: raise proper protocol errors
nonce: ByteString::try_from(nonce).unwrap(),
message: ByteString::try_from(buf).unwrap(),
})
}
}
#[derive(Debug)]
pub(crate) struct AuthenticationSaslFinal {
pub(crate) verifier: Box<[u8]>,
}
impl Deserialize<'_> for AuthenticationSaslFinal {
fn deserialize_with(buf: Bytes, _: ()) -> Result<Self> {
let mut verifier = Vec::new();
for item in buf.split(|b| *b == b',') {
let key = item[0];
let value = &item[2..];
if let b'v' = key {
// FIXME: raise proper protocol errors
verifier = base64::decode(value).unwrap();
}
}
Ok(Self { verifier: verifier.into_boxed_slice() })
}
}

View File

@ -0,0 +1,5 @@
mod startup;
mod terminate;
pub(crate) use startup::Startup;
pub(crate) use terminate::Terminate;

View File

@ -0,0 +1,55 @@
use sqlx_core::io::Serialize;
use sqlx_core::Result;
use crate::io::PgWriteExt;
#[derive(Debug)]
pub(crate) struct Startup<'a>(pub(crate) &'a [(&'a str, Option<&'a str>)]);
impl Serialize<'_> for Startup<'_> {
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
buf.write_len_prefixed(|buf| {
// The protocol version number. The most significant 16 bits are the
// major version number (3 for the protocol described here). The least
// significant 16 bits are the minor version number (0
// for the protocol described here)
buf.extend(&196_608_i32.to_be_bytes());
// For each startup parameter, write the name and value
// as NUL-terminated strings
for (name, value) in self.0 {
if let Some(value) = value {
write_startup_param(buf, name, value);
}
}
// Followed by a trailing NUL
buf.push(0);
Ok(())
})
}
}
fn write_startup_param(buf: &mut Vec<u8>, name: &str, value: &str) {
buf.reserve(name.len() + value.len() + 2);
buf.extend(name.as_bytes());
buf.push(0);
buf.extend(value.as_bytes());
buf.push(0);
}
#[cfg(test)]
mod tests {
use super::{Serialize, Startup};
#[test]
fn should_encode_startup() {
let mut buf = Vec::new();
let m = Startup(&[("user", Some("postgres")), ("database", Some("postgres"))]);
m.serialize(&mut buf).unwrap();
assert_eq!(buf, b"\0\0\0)\0\x03\0\0user\0postgres\0database\0postgres\0\0");
}
}

View File

@ -0,0 +1,32 @@
use sqlx_core::io::Serialize;
use sqlx_core::Result;
/// On receipt of this message, the backend closes the connection
/// and terminates.
#[derive(Debug)]
pub(crate) struct Terminate;
impl Serialize<'_> for Terminate {
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
buf.push(b'X');
Ok(())
}
}
#[cfg(test)]
mod tests {
use sqlx_core::io::Serialize;
use super::Terminate;
#[test]
fn should_serialize() -> anyhow::Result<()> {
let mut buf = Vec::new();
Terminate.serialize(&mut buf)?;
assert_eq!(&buf, &[b'X']);
Ok(())
}
}

View File

@ -3,11 +3,12 @@ use std::fmt::Debug;
use std::ops::{Deref, DerefMut};
use bytes::Buf;
use sqlx_core::io::{BufStream, Serialize};
use sqlx_core::io::{BufStream, Serialize, Stream};
use sqlx_core::net::Stream as NetStream;
use sqlx_core::{Result, Runtime};
use crate::protocol::{BackendMessage, BackendMessageType};
use crate::protocol::backend::{BackendMessage, BackendMessageType};
use crate::protocol::frontend::Terminate;
/// Reads and writes messages to and from the PostgreSQL database server.
///
@ -38,6 +39,10 @@ impl<Rt: Runtime> PgStream<Rt> {
where
T: Serialize<'ser> + Debug,
{
log::trace!("write > {:?}", message);
message.serialize(self.stream.buffer())?;
Ok(())
}
@ -86,7 +91,7 @@ impl<Rt: Runtime> PgStream<Rt> {
Ok(None)
}
_ => Ok(Some(BackendMessage { contents, r#type: ty })),
_ => Ok(Some(BackendMessage { contents, ty })),
}
}
}
@ -159,3 +164,29 @@ macro_rules! read_message {
$stream.read_message_async().await?
};
}
impl<Rt: Runtime> PgStream<Rt> {
#[cfg(feature = "async")]
pub(crate) async fn close_async(&mut self) -> Result<()>
where
Rt: sqlx_core::Async,
{
self.write_message(&Terminate)?;
self.flush_async().await?;
self.shutdown_async().await?;
Ok(())
}
#[cfg(feature = "blocking")]
pub(crate) fn close_blocking(&mut self) -> Result<()>
where
Rt: sqlx_core::blocking::Runtime,
{
self.write_message(&Terminate)?;
self.flush()?;
self.shutdown()?;
Ok(())
}
}