tls: update tls module for postgres

This commit is contained in:
Ryan Leckey
2020-03-01 23:08:20 -08:00
parent 370ad81b8e
commit 7fbc26de05
7 changed files with 146 additions and 63 deletions

View File

@@ -44,8 +44,9 @@ pub enum Error {
/// [Pool::close] was called while we were waiting in [Pool::acquire].
PoolClosed,
/// An error occurred during a TLS upgrade.
TlsUpgrade(Box<dyn StdError + Send + Sync>),
/// An error occurred while attempting to setup TLS.
/// This should only be returned from an explicit ask for TLS.
Tls(Box<dyn StdError + Send + Sync>),
/// An error occurred decoding data received from the database.
Decode(Box<dyn StdError + Send + Sync>),
@@ -67,7 +68,7 @@ impl StdError for Error {
Error::UrlParse(error) => Some(error),
Error::PoolTimedOut(Some(error)) => Some(&**error),
Error::Decode(error) => Some(&**error),
Error::TlsUpgrade(error) => Some(&**error),
Error::Tls(error) => Some(&**error),
_ => None,
}
@@ -111,7 +112,7 @@ impl Display for Error {
Error::PoolClosed => f.write_str("attempted to acquire a connection on a closed pool"),
Error::TlsUpgrade(ref err) => write!(f, "error during TLS upgrade: {}", err),
Error::Tls(ref err) => write!(f, "error during TLS upgrade: {}", err),
}
}
}
@@ -149,14 +150,14 @@ impl From<ProtocolError<'_>> for Error {
impl From<async_native_tls::Error> for Error {
#[inline]
fn from(err: async_native_tls::Error) -> Self {
Error::TlsUpgrade(err.into())
Error::Tls(err.into())
}
}
impl From<TlsError<'_>> for Error {
#[inline]
fn from(err: TlsError<'_>) -> Self {
Error::TlsUpgrade(err.args.to_string().into())
Error::Tls(err.args.to_string().into())
}
}

View File

@@ -485,9 +485,9 @@ impl MySqlConnection {
// On connect, server immediately sends the handshake
let mut handshake = self_.receive_handshake(&url).await?;
let ca_file = url.get_param("ssl-ca");
let ca_file = url.param("ssl-ca");
let ssl_mode = url.get_param("ssl-mode").unwrap_or(
let ssl_mode = url.param("ssl-mode").unwrap_or(
if ca_file.is_some() {
"VERIFY_CA"
} else {

View File

@@ -1,3 +1,4 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::convert::TryInto;
use std::ops::Range;
@@ -18,12 +19,10 @@ use crate::postgres::protocol::{
};
use crate::postgres::sasl;
use crate::postgres::stream::PgStream;
use crate::postgres::{PgError, PgTypeInfo};
use crate::postgres::{tls, PgError, PgTypeInfo};
use crate::url::Url;
use crate::{Error, Executor, Postgres};
// TODO: TLS
/// An asynchronous connection to a [Postgres][super::Postgres] database.
///
/// The connection string expected by [Connect::connect] should be a PostgreSQL connection
@@ -237,6 +236,7 @@ impl PgConnection {
let url = url?;
let mut stream = PgStream::new(&url).await?;
tls::request_if_needed(&mut stream, &url).await?;
startup(&mut stream, &url).await?;
Ok(Self {

View File

@@ -18,7 +18,7 @@ mod protocol;
mod row;
mod sasl;
mod stream;
// mod tls;
mod tls;
mod types;
/// An alias for [`Pool`][crate::Pool], specialized for **Postgres**.

View File

@@ -1,11 +1,13 @@
use byteorder::NetworkEndian;
use crate::io::BufMut;
use crate::postgres::protocol::Encode;
#[derive(Debug)]
pub struct SslRequest;
impl SslRequest {
pub fn encode(buf: &mut Vec<u8>) {
impl Encode for SslRequest {
fn encode(&self, buf: &mut Vec<u8>) {
// packet length: 8 bytes including self
buf.put_u32::<NetworkEndian>(8);
// 1234 in high 16 bits, 5679 in low 16
@@ -15,6 +17,7 @@ impl SslRequest {
#[test]
fn test_ssl_request() {
use crate::encode::Encode;
use crate::io::Buf;
let mut buf = Vec::new();

View File

@@ -1,66 +1,145 @@
use crate::postgres::protocol::SslRequest;
use crate::postgres::stream::PgStream;
use crate::postgres::PgConnection;
use crate::url::Url;
use std::borrow::Cow;
use std::fs::read;
impl PgConnection {
#[cfg(feature = "tls")]
pub(super) async fn try_ssl(
&mut self,
url: &Url,
invalid_certs: bool,
invalid_hostnames: bool,
) -> crate::Result<bool> {
use async_native_tls::TlsConnector;
SslRequest::encode(self.stream.buffer_mut());
self.stream.flush().await?;
match self.stream.peek(1).await? {
Some(b"N") => return Ok(false),
Some(b"S") => (),
Some(other) => {
return Err(tls_err!("unexpected single-byte response: 0x{:02X}", other[0]).into())
}
None => return Err(tls_err!("server unexpectedly closed connection").into()),
pub(crate) async fn request_if_needed(stream: &mut PgStream, url: &Url) -> crate::Result<()> {
// https://www.postgresql.org/docs/12/libpq-ssl.html#LIBPQ-SSL-SSLMODE-STATEMENTS
match url.param("sslmode").as_deref() {
Some("disable") | Some("allow") => {
// Do nothing
}
let mut connector = TlsConnector::new()
.danger_accept_invalid_certs(invalid_certs)
.danger_accept_invalid_hostnames(invalid_hostnames);
if !invalid_certs {
match read_root_certificate(&url).await {
Ok(cert) => {
connector = connector.add_root_certificate(cert);
}
Err(e) => log::warn!("failed to read Postgres root certificate: {}", e),
#[cfg(feature = "tls")]
Some("prefer") | None => {
// We default to [prefer] if TLS is compiled in
if !try_upgrade(stream, url, false, false).await? {
// TLS upgrade failed; fall back to a normal connection
}
}
self.stream.clear_bufs();
self.stream.stream.upgrade(url, connector).await?;
#[cfg(not(feature = "tls"))]
None => {
// The user neither explicitly enabled TLS in the connection string
// nor did they turn the `tls` feature on
Ok(true)
// Do nothing
}
#[cfg(feature = "tls")]
Some(mode @ "require") | Some(mode @ "verify-ca") | Some(mode @ "verify-full") => {
if !try_upgrade(
stream,
url,
// false for both verify-ca and verify-full
mode == "require",
// false for only verify-full
mode != "verify-full",
)
.await?
{
return Err(tls_err!("server does not support TLS").into());
}
}
#[cfg(not(feature = "tls"))]
Some(mode @ "prefer")
| Some(mode @ "require")
| Some(mode @ "verify-ca")
| Some(mode @ "verify-full") => {
return Err(tls_err!(
"sslmode {:?} unsupported; SQLx was compiled without `tls` feature",
mode
)
.into());
}
Some(mode) => {
return Err(tls_err!("unknown `sslmode` value: {:?}", mode).into());
}
}
Ok(())
}
#[cfg(feature = "tls")]
async fn read_root_certificate(url: &Url) -> crate::Result<async_native_tls::Certificate> {
async fn try_upgrade(
stream: &mut PgStream,
url: &Url,
accept_invalid_certs: bool,
accept_invalid_host_names: bool,
) -> crate::Result<bool> {
use async_native_tls::TlsConnector;
stream.write(SslRequest);
stream.flush().await?;
// The server then responds with a single byte containing S or N,
// indicating that it is willing or unwilling to perform SSL, respectively.
let ind = stream.stream.peek(1).await?[0];
stream.stream.consume(1);
match ind {
b'S' => {
// The server is ready and willing to accept an SSL connection
}
b'N' => {
// The server is _unwilling_ to perform SSL
return Ok(false);
}
other => {
return Err(tls_err!("unexpected response from SSLRequest: 0x{:02X}", other).into());
}
}
let mut connector = TlsConnector::new()
.danger_accept_invalid_certs(accept_invalid_certs)
.danger_accept_invalid_hostnames(accept_invalid_host_names);
if !accept_invalid_certs {
// Try to read in the root certificate for postgres using several
// standard methods (used by psql and libpq)
if let Some(cert) = read_root_certificate(&url).await? {
connector = connector.add_root_certificate(cert);
}
}
stream.stream.upgrade(url, connector).await?;
Ok(true)
}
#[cfg(feature = "tls")]
async fn read_root_certificate(url: &Url) -> crate::Result<Option<async_native_tls::Certificate>> {
use crate::runtime::fs;
use std::env;
let root_cert_path = if let Some(path) = url.get_param("sslrootcert") {
path.into()
} else if let Ok(cert_path) = env::var("PGSSLROOTCERT") {
cert_path
} else if cfg!(windows) {
let appdata = env::var("APPDATA").map_err(|_| tls_err!("APPDATA not set"))?;
format!("{}\\postgresql\\root.crt", appdata)
} else {
let home = env::var("HOME").map_err(|_| tls_err!("HOME not set"))?;
format!("{}/.postgresql/root.crt", home)
};
let mut data = None;
let root_cert = crate::runtime::fs::read(root_cert_path).await?;
Ok(async_native_tls::Certificate::from_pem(&root_cert)?)
if let Some(path) = url
.param("sslrootcert")
.or_else(|| env::var("PGSSLROOTCERT").ok().map(Into::into))
{
data = Some(fs::read(&*path).await?);
} else if cfg!(windows) {
if let Ok(app_data) = env::var("APPDATA") {
let path = format!("{}\\postgresql\\root.crt", app_data);
data = fs::read(path).await.ok();
}
} else {
if let Ok(home) = env::var("HOME") {
let path = format!("{}/.postgresql/root.crt", home);
data = fs::read(path).await.ok();
}
}
data.map(|data| async_native_tls::Certificate::from_pem(&data))
.transpose()
.map_err(Into::into)
}

View File

@@ -78,7 +78,7 @@ impl Url {
}
}
pub fn get_param(&self, key: &str) -> Option<Cow<str>> {
pub fn param(&self, key: &str) -> Option<Cow<str>> {
self.0
.query_pairs()
.find_map(|(key_, val)| if key == key_ { Some(val) } else { None })