Add rustls support

This commit is contained in:
Jonas Platte
2020-10-20 12:21:01 +02:00
committed by Ryan Leckey
parent 9298c88b87
commit b14266ba2e
12 changed files with 348 additions and 113 deletions

View File

@@ -38,11 +38,16 @@ runtime-actix-native-tls = [ "sqlx-rt/runtime-actix-native-tls", "_tls-native-tl
runtime-async-std-native-tls = [ "sqlx-rt/runtime-async-std-native-tls", "_tls-native-tls", "_rt-async-std" ]
runtime-tokio-native-tls = [ "sqlx-rt/runtime-tokio-native-tls", "_tls-native-tls", "_rt-tokio" ]
runtime-actix-rustls = [ "sqlx-rt/runtime-actix-rustls", "_tls-rustls", "_rt-actix" ]
runtime-async-std-rustls = [ "sqlx-rt/runtime-async-std-rustls", "_tls-rustls", "_rt-async-std" ]
runtime-tokio-rustls = [ "sqlx-rt/runtime-tokio-rustls", "_tls-rustls", "_rt-tokio" ]
# for conditional compilation
_rt-actix = []
_rt-async-std = []
_rt-tokio = []
_tls-native-tls = []
_tls-rustls = [ "rustls", "webpki" ]
# support offline/decoupled building (enables serialization of `Describe`)
offline = [ "serde", "either/serde" ]
@@ -86,6 +91,7 @@ parking_lot = "0.11.0"
rand = { version = "0.7.3", default-features = false, optional = true, features = [ "std" ] }
regex = { version = "1.3.9", optional = true }
rsa = { version = "0.3.0", optional = true }
rustls = { version = "0.18.1", optional = true }
serde = { version = "1.0.106", features = [ "derive", "rc" ], optional = true }
serde_json = { version = "1.0.51", features = [ "raw_value" ], optional = true }
sha-1 = { version = "0.9.0", default-features = false, optional = true }
@@ -96,6 +102,7 @@ time = { version = "0.2.16", optional = true }
smallvec = "1.4.0"
url = { version = "2.1.1", default-features = false }
uuid = { version = "0.8.1", default-features = false, optional = true, features = [ "std" ] }
webpki = { version = "0.21.3", optional = true }
whoami = "0.9.0"
stringprep = "0.1.2"
lru-cache = "0.1.2"

View File

@@ -242,6 +242,14 @@ impl From<sqlx_rt::native_tls::Error> for Error {
}
}
#[cfg(feature = "_tls-rustls")]
impl From<webpki::InvalidDNSNameError> for Error {
#[inline]
fn from(error: webpki::InvalidDNSNameError) -> Self {
Error::Tls(Box::new(error))
}
}
// Format an error message as a `Protocol` error
macro_rules! err_protocol {
($expr:expr) => {

View File

@@ -6,11 +6,7 @@ use std::path::Path;
use std::pin::Pin;
use std::task::{Context, Poll};
use sqlx_rt::{
fs,
native_tls::{Certificate, TlsConnector},
AsyncRead, AsyncWrite, TlsStream,
};
use sqlx_rt::{fs, AsyncRead, AsyncWrite, TlsStream};
use crate::error::Error;
use std::mem::replace;
@@ -40,25 +36,12 @@ where
accept_invalid_hostnames: bool,
root_cert_path: Option<&Path>,
) -> Result<(), Error> {
let mut builder = TlsConnector::builder();
builder
.danger_accept_invalid_certs(accept_invalid_certs)
.danger_accept_invalid_hostnames(accept_invalid_hostnames);
if !accept_invalid_certs {
if let Some(ca) = root_cert_path {
let data = fs::read(ca).await?;
let cert = Certificate::from_pem(&data)?;
builder.add_root_certificate(cert);
}
}
#[cfg(not(feature = "_rt-async-std"))]
let connector = sqlx_rt::TlsConnector::from(builder.build()?);
#[cfg(feature = "_rt-async-std")]
let connector = sqlx_rt::TlsConnector::from(builder);
let connector = configure_tls_connector(
accept_invalid_certs,
accept_invalid_hostnames,
root_cert_path,
)
.await?;
let stream = match replace(self, MaybeTlsStream::Upgrading) {
MaybeTlsStream::Raw(stream) => stream,
@@ -75,12 +58,71 @@ where
}
};
#[cfg(feature = "_tls-rustls")]
let host = webpki::DNSNameRef::try_from_ascii_str(host)?;
*self = MaybeTlsStream::Tls(connector.connect(host, stream).await?);
Ok(())
}
}
#[cfg(feature = "_tls-native-tls")]
async fn configure_tls_connector(
accept_invalid_certs: bool,
accept_invalid_hostnames: bool,
root_cert_path: Option<&Path>,
) -> Result<sqlx_rt::TlsConnector, Error> {
use sqlx_rt::native_tls::{Certificate, TlsConnector};
let mut builder = TlsConnector::builder();
builder
.danger_accept_invalid_certs(accept_invalid_certs)
.danger_accept_invalid_hostnames(accept_invalid_hostnames);
if !accept_invalid_certs {
if let Some(ca) = root_cert_path {
let data = fs::read(ca).await?;
let cert = Certificate::from_pem(&data)?;
builder.add_root_certificate(cert);
}
}
#[cfg(not(feature = "_rt-async-std"))]
let connector = builder.build()?.into();
#[cfg(feature = "_rt-async-std")]
let connector = builder.into();
Ok(connector)
}
#[cfg(feature = "_tls-rustls")]
async fn configure_tls_connector(
_accept_invalid_certs: bool,
_accept_invalid_hostnames: bool,
root_cert_path: Option<&Path>,
) -> Result<sqlx_rt::TlsConnector, Error> {
// FIXME: Support accept_invalid_certs / accept_invalid_hostnames
use rustls::ClientConfig;
use std::io::Cursor;
use std::sync::Arc;
let mut config = ClientConfig::new();
if let Some(ca) = root_cert_path {
let data = fs::read(ca).await?;
let mut cursor = Cursor::new(data);
config.root_store.add_pem_file(&mut cursor).map_err(|_| {
Error::Tls(format!("Invalid certificate file: {}", ca.display()).into())
})?;
}
Ok(Arc::new(config).into())
}
impl<S> AsyncRead for MaybeTlsStream<S>
where
S: Unpin + AsyncWrite + AsyncRead,
@@ -192,12 +234,15 @@ where
match self {
MaybeTlsStream::Raw(s) => s,
#[cfg(not(feature = "_rt-async-std"))]
MaybeTlsStream::Tls(s) => s.get_ref().get_ref().get_ref(),
#[cfg(feature = "_tls-rustls")]
MaybeTlsStream::Tls(s) => s.get_ref().0,
#[cfg(feature = "_rt-async-std")]
#[cfg(all(feature = "_rt-async-std", feature = "_tls-native-tls"))]
MaybeTlsStream::Tls(s) => s.get_ref(),
#[cfg(all(not(feature = "_rt-async-std"), feature = "_tls-native-tls"))]
MaybeTlsStream::Tls(s) => s.get_ref().get_ref().get_ref(),
MaybeTlsStream::Upgrading => panic!(io::Error::from(io::ErrorKind::ConnectionAborted)),
}
}
@@ -211,12 +256,15 @@ where
match self {
MaybeTlsStream::Raw(s) => s,
#[cfg(not(feature = "_rt-async-std"))]
MaybeTlsStream::Tls(s) => s.get_mut().get_mut().get_mut(),
#[cfg(feature = "_tls-rustls")]
MaybeTlsStream::Tls(s) => s.get_mut().0,
#[cfg(feature = "_rt-async-std")]
#[cfg(all(feature = "_rt-async-std", feature = "_tls-native-tls"))]
MaybeTlsStream::Tls(s) => s.get_mut(),
#[cfg(all(not(feature = "_rt-async-std"), feature = "_tls-native-tls"))]
MaybeTlsStream::Tls(s) => s.get_mut().get_mut().get_mut(),
MaybeTlsStream::Upgrading => panic!(io::Error::from(io::ErrorKind::ConnectionAborted)),
}
}