mirror of
https://github.com/launchbadge/sqlx.git
synced 2026-04-30 22:14:23 +00:00
tls: update tls module for postgres
This commit is contained in:
@@ -44,8 +44,9 @@ pub enum Error {
|
||||
/// [Pool::close] was called while we were waiting in [Pool::acquire].
|
||||
PoolClosed,
|
||||
|
||||
/// An error occurred during a TLS upgrade.
|
||||
TlsUpgrade(Box<dyn StdError + Send + Sync>),
|
||||
/// An error occurred while attempting to setup TLS.
|
||||
/// This should only be returned from an explicit ask for TLS.
|
||||
Tls(Box<dyn StdError + Send + Sync>),
|
||||
|
||||
/// An error occurred decoding data received from the database.
|
||||
Decode(Box<dyn StdError + Send + Sync>),
|
||||
@@ -67,7 +68,7 @@ impl StdError for Error {
|
||||
Error::UrlParse(error) => Some(error),
|
||||
Error::PoolTimedOut(Some(error)) => Some(&**error),
|
||||
Error::Decode(error) => Some(&**error),
|
||||
Error::TlsUpgrade(error) => Some(&**error),
|
||||
Error::Tls(error) => Some(&**error),
|
||||
|
||||
_ => None,
|
||||
}
|
||||
@@ -111,7 +112,7 @@ impl Display for Error {
|
||||
|
||||
Error::PoolClosed => f.write_str("attempted to acquire a connection on a closed pool"),
|
||||
|
||||
Error::TlsUpgrade(ref err) => write!(f, "error during TLS upgrade: {}", err),
|
||||
Error::Tls(ref err) => write!(f, "error during TLS upgrade: {}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -149,14 +150,14 @@ impl From<ProtocolError<'_>> for Error {
|
||||
impl From<async_native_tls::Error> for Error {
|
||||
#[inline]
|
||||
fn from(err: async_native_tls::Error) -> Self {
|
||||
Error::TlsUpgrade(err.into())
|
||||
Error::Tls(err.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TlsError<'_>> for Error {
|
||||
#[inline]
|
||||
fn from(err: TlsError<'_>) -> Self {
|
||||
Error::TlsUpgrade(err.args.to_string().into())
|
||||
Error::Tls(err.args.to_string().into())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -485,9 +485,9 @@ impl MySqlConnection {
|
||||
// On connect, server immediately sends the handshake
|
||||
let mut handshake = self_.receive_handshake(&url).await?;
|
||||
|
||||
let ca_file = url.get_param("ssl-ca");
|
||||
let ca_file = url.param("ssl-ca");
|
||||
|
||||
let ssl_mode = url.get_param("ssl-mode").unwrap_or(
|
||||
let ssl_mode = url.param("ssl-mode").unwrap_or(
|
||||
if ca_file.is_some() {
|
||||
"VERIFY_CA"
|
||||
} else {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::TryInto;
|
||||
use std::ops::Range;
|
||||
@@ -18,12 +19,10 @@ use crate::postgres::protocol::{
|
||||
};
|
||||
use crate::postgres::sasl;
|
||||
use crate::postgres::stream::PgStream;
|
||||
use crate::postgres::{PgError, PgTypeInfo};
|
||||
use crate::postgres::{tls, PgError, PgTypeInfo};
|
||||
use crate::url::Url;
|
||||
use crate::{Error, Executor, Postgres};
|
||||
|
||||
// TODO: TLS
|
||||
|
||||
/// An asynchronous connection to a [Postgres][super::Postgres] database.
|
||||
///
|
||||
/// The connection string expected by [Connect::connect] should be a PostgreSQL connection
|
||||
@@ -237,6 +236,7 @@ impl PgConnection {
|
||||
let url = url?;
|
||||
let mut stream = PgStream::new(&url).await?;
|
||||
|
||||
tls::request_if_needed(&mut stream, &url).await?;
|
||||
startup(&mut stream, &url).await?;
|
||||
|
||||
Ok(Self {
|
||||
|
||||
@@ -18,7 +18,7 @@ mod protocol;
|
||||
mod row;
|
||||
mod sasl;
|
||||
mod stream;
|
||||
// mod tls;
|
||||
mod tls;
|
||||
mod types;
|
||||
|
||||
/// An alias for [`Pool`][crate::Pool], specialized for **Postgres**.
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
use byteorder::NetworkEndian;
|
||||
|
||||
use crate::io::BufMut;
|
||||
use crate::postgres::protocol::Encode;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SslRequest;
|
||||
|
||||
impl SslRequest {
|
||||
pub fn encode(buf: &mut Vec<u8>) {
|
||||
impl Encode for SslRequest {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
// packet length: 8 bytes including self
|
||||
buf.put_u32::<NetworkEndian>(8);
|
||||
// 1234 in high 16 bits, 5679 in low 16
|
||||
@@ -15,6 +17,7 @@ impl SslRequest {
|
||||
|
||||
#[test]
|
||||
fn test_ssl_request() {
|
||||
use crate::encode::Encode;
|
||||
use crate::io::Buf;
|
||||
|
||||
let mut buf = Vec::new();
|
||||
|
||||
@@ -1,66 +1,145 @@
|
||||
use crate::postgres::protocol::SslRequest;
|
||||
use crate::postgres::stream::PgStream;
|
||||
use crate::postgres::PgConnection;
|
||||
use crate::url::Url;
|
||||
use std::borrow::Cow;
|
||||
use std::fs::read;
|
||||
|
||||
impl PgConnection {
|
||||
#[cfg(feature = "tls")]
|
||||
pub(super) async fn try_ssl(
|
||||
&mut self,
|
||||
url: &Url,
|
||||
invalid_certs: bool,
|
||||
invalid_hostnames: bool,
|
||||
) -> crate::Result<bool> {
|
||||
use async_native_tls::TlsConnector;
|
||||
|
||||
SslRequest::encode(self.stream.buffer_mut());
|
||||
|
||||
self.stream.flush().await?;
|
||||
|
||||
match self.stream.peek(1).await? {
|
||||
Some(b"N") => return Ok(false),
|
||||
Some(b"S") => (),
|
||||
Some(other) => {
|
||||
return Err(tls_err!("unexpected single-byte response: 0x{:02X}", other[0]).into())
|
||||
}
|
||||
None => return Err(tls_err!("server unexpectedly closed connection").into()),
|
||||
pub(crate) async fn request_if_needed(stream: &mut PgStream, url: &Url) -> crate::Result<()> {
|
||||
// https://www.postgresql.org/docs/12/libpq-ssl.html#LIBPQ-SSL-SSLMODE-STATEMENTS
|
||||
match url.param("sslmode").as_deref() {
|
||||
Some("disable") | Some("allow") => {
|
||||
// Do nothing
|
||||
}
|
||||
|
||||
let mut connector = TlsConnector::new()
|
||||
.danger_accept_invalid_certs(invalid_certs)
|
||||
.danger_accept_invalid_hostnames(invalid_hostnames);
|
||||
|
||||
if !invalid_certs {
|
||||
match read_root_certificate(&url).await {
|
||||
Ok(cert) => {
|
||||
connector = connector.add_root_certificate(cert);
|
||||
}
|
||||
Err(e) => log::warn!("failed to read Postgres root certificate: {}", e),
|
||||
#[cfg(feature = "tls")]
|
||||
Some("prefer") | None => {
|
||||
// We default to [prefer] if TLS is compiled in
|
||||
if !try_upgrade(stream, url, false, false).await? {
|
||||
// TLS upgrade failed; fall back to a normal connection
|
||||
}
|
||||
}
|
||||
|
||||
self.stream.clear_bufs();
|
||||
self.stream.stream.upgrade(url, connector).await?;
|
||||
#[cfg(not(feature = "tls"))]
|
||||
None => {
|
||||
// The user neither explicitly enabled TLS in the connection string
|
||||
// nor did they turn the `tls` feature on
|
||||
|
||||
Ok(true)
|
||||
// Do nothing
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
Some(mode @ "require") | Some(mode @ "verify-ca") | Some(mode @ "verify-full") => {
|
||||
if !try_upgrade(
|
||||
stream,
|
||||
url,
|
||||
// false for both verify-ca and verify-full
|
||||
mode == "require",
|
||||
// false for only verify-full
|
||||
mode != "verify-full",
|
||||
)
|
||||
.await?
|
||||
{
|
||||
return Err(tls_err!("server does not support TLS").into());
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "tls"))]
|
||||
Some(mode @ "prefer")
|
||||
| Some(mode @ "require")
|
||||
| Some(mode @ "verify-ca")
|
||||
| Some(mode @ "verify-full") => {
|
||||
return Err(tls_err!(
|
||||
"sslmode {:?} unsupported; SQLx was compiled without `tls` feature",
|
||||
mode
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
Some(mode) => {
|
||||
return Err(tls_err!("unknown `sslmode` value: {:?}", mode).into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
async fn read_root_certificate(url: &Url) -> crate::Result<async_native_tls::Certificate> {
|
||||
async fn try_upgrade(
|
||||
stream: &mut PgStream,
|
||||
url: &Url,
|
||||
accept_invalid_certs: bool,
|
||||
accept_invalid_host_names: bool,
|
||||
) -> crate::Result<bool> {
|
||||
use async_native_tls::TlsConnector;
|
||||
|
||||
stream.write(SslRequest);
|
||||
stream.flush().await?;
|
||||
|
||||
// The server then responds with a single byte containing S or N,
|
||||
// indicating that it is willing or unwilling to perform SSL, respectively.
|
||||
let ind = stream.stream.peek(1).await?[0];
|
||||
stream.stream.consume(1);
|
||||
|
||||
match ind {
|
||||
b'S' => {
|
||||
// The server is ready and willing to accept an SSL connection
|
||||
}
|
||||
|
||||
b'N' => {
|
||||
// The server is _unwilling_ to perform SSL
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
other => {
|
||||
return Err(tls_err!("unexpected response from SSLRequest: 0x{:02X}", other).into());
|
||||
}
|
||||
}
|
||||
|
||||
let mut connector = TlsConnector::new()
|
||||
.danger_accept_invalid_certs(accept_invalid_certs)
|
||||
.danger_accept_invalid_hostnames(accept_invalid_host_names);
|
||||
|
||||
if !accept_invalid_certs {
|
||||
// Try to read in the root certificate for postgres using several
|
||||
// standard methods (used by psql and libpq)
|
||||
if let Some(cert) = read_root_certificate(&url).await? {
|
||||
connector = connector.add_root_certificate(cert);
|
||||
}
|
||||
}
|
||||
|
||||
stream.stream.upgrade(url, connector).await?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
async fn read_root_certificate(url: &Url) -> crate::Result<Option<async_native_tls::Certificate>> {
|
||||
use crate::runtime::fs;
|
||||
use std::env;
|
||||
|
||||
let root_cert_path = if let Some(path) = url.get_param("sslrootcert") {
|
||||
path.into()
|
||||
} else if let Ok(cert_path) = env::var("PGSSLROOTCERT") {
|
||||
cert_path
|
||||
} else if cfg!(windows) {
|
||||
let appdata = env::var("APPDATA").map_err(|_| tls_err!("APPDATA not set"))?;
|
||||
format!("{}\\postgresql\\root.crt", appdata)
|
||||
} else {
|
||||
let home = env::var("HOME").map_err(|_| tls_err!("HOME not set"))?;
|
||||
format!("{}/.postgresql/root.crt", home)
|
||||
};
|
||||
let mut data = None;
|
||||
|
||||
let root_cert = crate::runtime::fs::read(root_cert_path).await?;
|
||||
Ok(async_native_tls::Certificate::from_pem(&root_cert)?)
|
||||
if let Some(path) = url
|
||||
.param("sslrootcert")
|
||||
.or_else(|| env::var("PGSSLROOTCERT").ok().map(Into::into))
|
||||
{
|
||||
data = Some(fs::read(&*path).await?);
|
||||
} else if cfg!(windows) {
|
||||
if let Ok(app_data) = env::var("APPDATA") {
|
||||
let path = format!("{}\\postgresql\\root.crt", app_data);
|
||||
|
||||
data = fs::read(path).await.ok();
|
||||
}
|
||||
} else {
|
||||
if let Ok(home) = env::var("HOME") {
|
||||
let path = format!("{}/.postgresql/root.crt", home);
|
||||
|
||||
data = fs::read(path).await.ok();
|
||||
}
|
||||
}
|
||||
|
||||
data.map(|data| async_native_tls::Certificate::from_pem(&data))
|
||||
.transpose()
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
@@ -78,7 +78,7 @@ impl Url {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_param(&self, key: &str) -> Option<Cow<str>> {
|
||||
pub fn param(&self, key: &str) -> Option<Cow<str>> {
|
||||
self.0
|
||||
.query_pairs()
|
||||
.find_map(|(key_, val)| if key == key_ { Some(val) } else { None })
|
||||
|
||||
Reference in New Issue
Block a user