implement TLS support for Postgres

This commit is contained in:
Austin Bonander
2020-01-09 21:57:51 -08:00
parent 6c8fd949dd
commit 638852a2dd
12 changed files with 301 additions and 12 deletions

View File

@@ -49,6 +49,9 @@ jobs:
# -----------------------------------------------------
# Check that we build with TLS support (TODO: we need a postgres image with SSL certs to test)
- run: cargo build -p sqlx-core --no-default-features 'postgres macros uuid chrono tls'
- run: cargo test -p sqlx --no-default-features --features 'postgres macros uuid chrono'
env:
DATABASE_URL: postgres://postgres:postgres@localhost:${{ job.services.postgres.ports[5432] }}/postgres

View File

@@ -30,6 +30,7 @@ all-features = true
[features]
default = [ "macros" ]
macros = [ "sqlx-macros", "proc-macro-hack" ]
tls = ["sqlx-core/tls"]
# database
postgres = [ "sqlx-core/postgres", "sqlx-macros/postgres" ]
@@ -48,6 +49,7 @@ hex = "0.4.0"
[dev-dependencies]
anyhow = "1.0.26"
futures = "0.3.1"
env_logger = "0.7"
async-std = { version = "1.4.0", features = [ "attributes" ] }
dotenv = "0.15.0"

View File

@@ -20,8 +20,10 @@ default = []
unstable = []
postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac" ]
mysql = [ "sha-1", "sha2", "generic-array", "num-bigint", "base64", "digest", "rand" ]
tls = ["async-native-tls"]
[dependencies]
async-native-tls = { version = "0.3", optional = true }
async-std = "1.4.0"
async-stream = { version = "0.2.0", default-features = false }
base64 = { version = "0.11.0", default-features = false, optional = true, features = [ "std" ] }

View File

@@ -44,6 +44,9 @@ pub enum Error {
/// [Pool::close] was called while we were waiting in [Pool::acquire].
PoolClosed,
/// An error occurred during a TLS upgrade.
TlsUpgrade(Box<dyn StdError + Send + Sync>),
Decode(DecodeError),
// TODO: Remove and replace with `#[non_exhaustive]` when possible
@@ -62,6 +65,8 @@ impl StdError for Error {
Error::Decode(DecodeError::Other(error)) => Some(&**error),
Error::TlsUpgrade(error) => Some(&**error),
_ => None,
}
}
@@ -100,6 +105,8 @@ impl Display for Error {
Error::PoolClosed => f.write_str("attempted to acquire a connection on a closed pool"),
Error::TlsUpgrade(ref err) => write!(f, "error during TLS upgrade: {}", err),
Error::__Nonexhaustive => unreachable!(),
}
}
@@ -140,6 +147,21 @@ impl From<ProtocolError<'_>> for Error {
}
}
#[cfg(feature = "tls")]
impl From<async_native_tls::Error> for Error {
#[inline]
fn from(err: async_native_tls::Error) -> Self {
Error::TlsUpgrade(err.into())
}
}
impl From<TlsError<'_>> for Error {
#[inline]
fn from(err: TlsError<'_>) -> Self {
Error::TlsUpgrade(err.args.to_string().into())
}
}
impl<T> From<T> for Error
where
T: 'static + DatabaseError,
@@ -189,6 +211,15 @@ macro_rules! protocol_err (
}
);
pub(crate) struct TlsError<'a> {
pub args: fmt::Arguments<'a>,
}
#[allow(unused_macros)]
macro_rules! tls_err {
($($args:tt)*) => { crate::error::TlsError { args: format_args!($($args)*)} };
}
#[allow(unused_macros)]
macro_rules! impl_fmt_error {
($err:ty) => {
@@ -212,3 +243,4 @@ macro_rules! impl_fmt_error {
}
};
}

View File

@@ -51,6 +51,12 @@ where
Ok(())
}
pub fn clear_bufs(&mut self) {
self.rbuf_rindex = 0;
self.rbuf_windex = 0;
self.wbuf.clear();
}
#[inline]
pub fn consume(&mut self, cnt: usize) {
self.rbuf_rindex += cnt;

View File

@@ -5,11 +5,14 @@ mod buf;
mod buf_mut;
mod byte_str;
mod tls;
pub use self::{
buf::{Buf, ToBuf},
buf_mut::BufMut,
buf_stream::BufStream,
byte_str::ByteStr,
tls::MaybeTlsStream
};
#[cfg(test)]

109
sqlx-core/src/io/tls.rs Normal file
View File

@@ -0,0 +1,109 @@
use std::io::{IoSlice, IoSliceMut};
use std::pin::Pin;
use std::task::{Context, Poll};
use async_std::io::{self, Read, Write};
use async_std::net::{Shutdown, TcpStream};
use crate::url::Url;
use self::Inner::*;
pub struct MaybeTlsStream {
inner: Inner,
}
enum Inner {
NotTls(TcpStream),
#[cfg(feature = "tls")]
Tls(async_native_tls::TlsStream<TcpStream>),
#[cfg(feature = "tls")]
Upgrading,
}
impl MaybeTlsStream {
pub async fn connect(url: &Url, default_port: u16) -> crate::Result<Self> {
let conn = TcpStream::connect((url.host(), url.port(default_port))).await?;
Ok(Self { inner: Inner::NotTls(conn) })
}
#[cfg(feature = "tls")]
pub async fn upgrade(&mut self, url: &Url, connector: async_native_tls::TlsConnector) -> crate::Result<()> {
let conn = match std::mem::replace(&mut self.inner, Upgrading) {
NotTls(conn) => conn,
Tls(_) => return Err(tls_err!("connection already upgraded").into()),
Upgrading => return Err(tls_err!("connection already failed to upgrade").into()),
};
self.inner = Tls(connector.connect(url.host(), conn).await?);
Ok(())
}
pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
match self.inner {
NotTls(ref conn) => conn.shutdown(how),
#[cfg(feature = "tls")]
Tls(ref conn) => conn.get_ref().shutdown(how),
#[cfg(feature = "tls")]
// connection already closed
Upgrading => Ok(()),
}
}
}
macro_rules! forward_pin (
($self:ident.$method:ident($($arg:ident),*)) => (
match &mut $self.inner {
NotTls(ref mut conn) => Pin::new(conn).$method($($arg),*),
#[cfg(feature = "tls")]
Tls(ref mut conn) => Pin::new(conn).$method($($arg),*),
#[cfg(feature = "tls")]
Upgrading => Err(io::Error::new(io::ErrorKind::Other, "connection broken; TLS upgrade failed")).into(),
}
)
);
impl Read for MaybeTlsStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
forward_pin!(self.poll_read(cx, buf))
}
fn poll_read_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context,
bufs: &mut [IoSliceMut],
) -> Poll<io::Result<usize>> {
forward_pin!(self.poll_read_vectored(cx, bufs))
}
}
impl Write for MaybeTlsStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
forward_pin!(self.poll_write(cx, buf))
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
forward_pin!(self.poll_flush(cx))
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
forward_pin!(self.poll_close(cx))
}
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context,
bufs: &[IoSlice],
) -> Poll<io::Result<usize>> {
forward_pin!(self.poll_write_vectored(cx, bufs))
}
}

View File

@@ -1,6 +1,8 @@
use std::convert::TryInto;
use std::path::Path;
use async_std::net::{Shutdown, TcpStream};
use async_std::fs;
use async_std::net::Shutdown;
use byteorder::NetworkEndian;
use futures_core::future::BoxFuture;
@@ -11,12 +13,15 @@ use crate::postgres::protocol::{
self, hi, Authentication, Decode, Encode, Message, SaslInitialResponse, SaslResponse,
StatementId,
};
use crate::io::{Buf, BufStream, MaybeTlsStream};
use crate::postgres::PgError;
use crate::postgres::protocol::{self, Decode, Encode, Message, StatementId};
use crate::url::Url;
use crate::Result;
use hmac::{Hmac, Mac};
use rand::Rng;
use sha2::{Digest, Sha256};
use async_native_tls::Certificate;
/// An asynchronous connection to a [Postgres] database.
///
@@ -24,7 +29,7 @@ use sha2::{Digest, Sha256};
/// string, as documented at
/// <https://www.postgresql.org/docs/12/libpq-connect.html#LIBPQ-CONNSTRING>
pub struct PgConnection {
pub(super) stream: BufStream<TcpStream>,
pub(super) stream: BufStream<MaybeTlsStream>,
// Map of query to statement id
pub(super) statement_cache: StatementCache<StatementId>,
@@ -43,8 +48,43 @@ pub struct PgConnection {
}
impl PgConnection {
#[cfg(feature = "tls")]
async fn try_ssl(&mut self, url: &Url, invalid_certs: bool, invalid_hostnames: bool) -> crate::Result<bool> {
use async_native_tls::{TlsConnector, Certificate};
use std::env;
protocol::SslRequest::encode(self.stream.buffer_mut());
self.stream.flush().await?;
match self.stream.peek(1).await? {
Some(b"N") => return Ok(false),
Some(b"S") => (),
Some(other) => return Err(tls_err!("unexpected single-byte response: 0x{:02X}", other[0]).into()),
None => return Err(tls_err!("server unexpectedly closed connection").into())
}
let mut connector = TlsConnector::new()
.danger_accept_invalid_certs(invalid_certs)
.danger_accept_invalid_hostnames(invalid_hostnames);
if !invalid_certs {
match read_root_certificate(&url).await {
Ok(cert) => {
connector = connector.add_root_certificate(cert);
}
Err(e) => log::warn!("failed to read Postgres root certificate: {}", e)
}
}
self.stream.clear_bufs();
self.stream.stream.upgrade(url, connector).await?;
Ok(true)
}
// https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3
async fn startup(&mut self, url: Url) -> Result<()> {
async fn startup(&mut self, url: &Url) -> Result<()> {
// Defaults to postgres@.../postgres
let username = url.username().unwrap_or("postgres");
let database = url.database().unwrap_or("postgres");
@@ -83,7 +123,7 @@ impl PgConnection {
protocol::PasswordMessage::ClearText(
url.password().unwrap_or_default(),
)
.encode(self.stream.buffer_mut());
.encode(self.stream.buffer_mut());
self.stream.flush().await?;
}
@@ -94,7 +134,7 @@ impl PgConnection {
user: username,
salt,
}
.encode(self.stream.buffer_mut());
.encode(self.stream.buffer_mut());
self.stream.flush().await?;
}
@@ -137,7 +177,7 @@ impl PgConnection {
"requires unimplemented authentication method: {:?}",
auth
)
.into());
.into());
}
}
}
@@ -240,7 +280,8 @@ impl PgConnection {
impl PgConnection {
pub(super) async fn open(url: Result<Url>) -> Result<Self> {
let url = url?;
let stream = TcpStream::connect((url.host(), url.port(5432))).await?;
let stream = MaybeTlsStream::connect(&url, 5432).await?;
let mut self_ = Self {
stream: BufStream::new(stream),
process_id: 0,
@@ -251,7 +292,36 @@ impl PgConnection {
ready: true,
};
self_.startup(url).await?;
let ssl_mode = url.get_param("sslmode").unwrap_or("prefer".into());
match &*ssl_mode {
// TODO: on "allow" retry with TLS if startup fails
"disable" | "allow" => (),
#[cfg(feature = "tls")]
"prefer" => { self_.try_ssl(&url, true, true).await?; },
#[cfg(not(feature = "tls"))]
"prefer" => log::info!("compiled without TLS, skipping upgrade"),
#[cfg(feature = "tls")]
"require" | "verify-ca" | "verify-full" => if !self_.try_ssl(
&url,
ssl_mode == "require", // false for both verify-ca and verify-full
ssl_mode != "verify-full" // false for only verify-full
).await? {
return Err(tls_err!("Postgres server does not support TLS").into())
}
#[cfg(not(feature = "tls"))]
"require" | "verify-ca" | "verify-full" => return Err(
tls_err!("sslmode {:?} unsupported; SQLx was compiled without `tls` feature",
ssl_mode).into()
),
_ => return Err(tls_err!("unknown `sslmode` value: {:?}", ssl_mode).into()),
}
self_.startup(&url).await?;
Ok(self_)
}
@@ -259,9 +329,9 @@ impl PgConnection {
impl Connection for PgConnection {
fn open<T>(url: T) -> BoxFuture<'static, Result<Self>>
where
T: TryInto<Url, Error = crate::Error>,
Self: Sized,
where
T: TryInto<Url, Error=crate::Error>,
Self: Sized,
{
Box::pin(PgConnection::open(url.try_into()))
}
@@ -271,6 +341,26 @@ impl Connection for PgConnection {
}
}
#[cfg(feature = "tls")]
async fn read_root_certificate(url: &Url) -> crate::Result<async_native_tls::Certificate> {
use std::env;
let root_cert_path = if let Some(path) = url.get_param("sslrootcert") {
path.into()
} else if let Ok(cert_path) = env::var("PGSSLROOTCERT"){
cert_path
} else if cfg!(windows) {
let appdata = env::var("APPDATA").map_err(|_| tls_err!("APPDATA not set"))?;
format!("{}\\postgresql\\root.crt", appdata)
} else {
let home = env::var("HOME").map_err(|_| tls_err!("HOME not set"))?;
format!("{}/.postgresql/root.crt", home)
};
let root_cert = async_std::fs::read(root_cert_path).await?;
Ok(async_native_tls::Certificate::from_pem(&root_cert)?)
}
static GS2_HEADER: &'static str = "n,,";
static CHANNEL_ATTR: &'static str = "c";
static USERNAME_ATTR: &'static str = "n";
@@ -354,7 +444,7 @@ async fn sasl_auth<T: AsRef<str>>(conn: &mut PgConnection, username: T, password
);
// AuthMessage := client-first-message-bare + "," + server-first-message + "," + client-final-message-without-proof
let auth_message = format!("{client_first_message_bare},{server_first_message},{client_final_message_wo_proof}",
let auth_message = format!("{client_first_message_bare},{server_first_message},{client_final_message_wo_proof}",
client_first_message_bare = client_first_message_bare,
server_first_message = server_first_message,
client_final_message_wo_proof = client_final_message_wo_proof);

View File

@@ -5,6 +5,7 @@
// the size of this module to exactly what is necessary.
#![allow(unused)]
// REQUESTS
mod bind;
mod cancel_request;
mod close;
@@ -16,6 +17,7 @@ mod parse;
mod password_message;
mod query;
mod sasl;
mod ssl_request;
mod startup_message;
mod statement;
mod sync;
@@ -32,11 +34,13 @@ pub use parse::Parse;
pub use password_message::PasswordMessage;
pub use query::Query;
pub use sasl::{hi, SaslInitialResponse, SaslResponse};
pub use ssl_request::SslRequest;
pub use startup_message::StartupMessage;
pub use statement::StatementId;
pub use sync::Sync;
pub use terminate::Terminate;
// RESPONSES
mod authentication;
mod backend_key_data;
mod command_complete;

View File

@@ -0,0 +1,25 @@
use crate::io::{Buf, BufMut};
use byteorder::NetworkEndian;
pub struct SslRequest;
impl SslRequest {
pub fn encode(buf: &mut Vec<u8>) {
// packet length: 8 bytes including self
buf.put_u32::<NetworkEndian>(8);
// 1234 in high 16 bits, 5679 in low 16
buf.put_u32::<NetworkEndian>(
(1234 << 16) | 5679,
);
}
}
#[test]
fn test_ssl_request() {
use crate::io::Buf;
let mut buf = Vec::new();
SslRequest::encode(&mut buf);
assert_eq!((&buf[..]).get_u32::<NetworkEndian>().unwrap(), 80877103);
}

View File

@@ -1,4 +1,5 @@
use std::convert::{TryFrom, TryInto};
use std::borrow::Cow;
pub struct Url(url::Url);
@@ -64,4 +65,14 @@ impl Url {
Some(database)
}
}
pub fn get_param(&self, key: &str) -> Option<Cow<str>> {
self.0.query_pairs().find_map(|(key_, val)| {
if key == key_ {
Some(val)
} else {
None
}
})
}
}

View File

@@ -66,5 +66,7 @@ async fn it_remains_stable_issue_30() -> anyhow::Result<()> {
}
async fn connect() -> anyhow::Result<PgConnection> {
let _ = dotenv::dotenv();
let _ = env_logger::try_init();
Ok(PgConnection::open(dotenv::var("DATABASE_URL")?).await?)
}