mirror of
https://github.com/launchbadge/sqlx.git
synced 2026-04-04 09:31:05 +00:00
Add rustls support
This commit is contained in:
committed by
Ryan Leckey
parent
9298c88b87
commit
b14266ba2e
@@ -38,11 +38,16 @@ runtime-actix-native-tls = [ "sqlx-rt/runtime-actix-native-tls", "_tls-native-tl
|
||||
runtime-async-std-native-tls = [ "sqlx-rt/runtime-async-std-native-tls", "_tls-native-tls", "_rt-async-std" ]
|
||||
runtime-tokio-native-tls = [ "sqlx-rt/runtime-tokio-native-tls", "_tls-native-tls", "_rt-tokio" ]
|
||||
|
||||
runtime-actix-rustls = [ "sqlx-rt/runtime-actix-rustls", "_tls-rustls", "_rt-actix" ]
|
||||
runtime-async-std-rustls = [ "sqlx-rt/runtime-async-std-rustls", "_tls-rustls", "_rt-async-std" ]
|
||||
runtime-tokio-rustls = [ "sqlx-rt/runtime-tokio-rustls", "_tls-rustls", "_rt-tokio" ]
|
||||
|
||||
# for conditional compilation
|
||||
_rt-actix = []
|
||||
_rt-async-std = []
|
||||
_rt-tokio = []
|
||||
_tls-native-tls = []
|
||||
_tls-rustls = [ "rustls", "webpki" ]
|
||||
|
||||
# support offline/decoupled building (enables serialization of `Describe`)
|
||||
offline = [ "serde", "either/serde" ]
|
||||
@@ -86,6 +91,7 @@ parking_lot = "0.11.0"
|
||||
rand = { version = "0.7.3", default-features = false, optional = true, features = [ "std" ] }
|
||||
regex = { version = "1.3.9", optional = true }
|
||||
rsa = { version = "0.3.0", optional = true }
|
||||
rustls = { version = "0.18.1", optional = true }
|
||||
serde = { version = "1.0.106", features = [ "derive", "rc" ], optional = true }
|
||||
serde_json = { version = "1.0.51", features = [ "raw_value" ], optional = true }
|
||||
sha-1 = { version = "0.9.0", default-features = false, optional = true }
|
||||
@@ -96,6 +102,7 @@ time = { version = "0.2.16", optional = true }
|
||||
smallvec = "1.4.0"
|
||||
url = { version = "2.1.1", default-features = false }
|
||||
uuid = { version = "0.8.1", default-features = false, optional = true, features = [ "std" ] }
|
||||
webpki = { version = "0.21.3", optional = true }
|
||||
whoami = "0.9.0"
|
||||
stringprep = "0.1.2"
|
||||
lru-cache = "0.1.2"
|
||||
|
||||
@@ -242,6 +242,14 @@ impl From<sqlx_rt::native_tls::Error> for Error {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "_tls-rustls")]
|
||||
impl From<webpki::InvalidDNSNameError> for Error {
|
||||
#[inline]
|
||||
fn from(error: webpki::InvalidDNSNameError) -> Self {
|
||||
Error::Tls(Box::new(error))
|
||||
}
|
||||
}
|
||||
|
||||
// Format an error message as a `Protocol` error
|
||||
macro_rules! err_protocol {
|
||||
($expr:expr) => {
|
||||
|
||||
@@ -6,11 +6,7 @@ use std::path::Path;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use sqlx_rt::{
|
||||
fs,
|
||||
native_tls::{Certificate, TlsConnector},
|
||||
AsyncRead, AsyncWrite, TlsStream,
|
||||
};
|
||||
use sqlx_rt::{fs, AsyncRead, AsyncWrite, TlsStream};
|
||||
|
||||
use crate::error::Error;
|
||||
use std::mem::replace;
|
||||
@@ -40,25 +36,12 @@ where
|
||||
accept_invalid_hostnames: bool,
|
||||
root_cert_path: Option<&Path>,
|
||||
) -> Result<(), Error> {
|
||||
let mut builder = TlsConnector::builder();
|
||||
builder
|
||||
.danger_accept_invalid_certs(accept_invalid_certs)
|
||||
.danger_accept_invalid_hostnames(accept_invalid_hostnames);
|
||||
|
||||
if !accept_invalid_certs {
|
||||
if let Some(ca) = root_cert_path {
|
||||
let data = fs::read(ca).await?;
|
||||
let cert = Certificate::from_pem(&data)?;
|
||||
|
||||
builder.add_root_certificate(cert);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "_rt-async-std"))]
|
||||
let connector = sqlx_rt::TlsConnector::from(builder.build()?);
|
||||
|
||||
#[cfg(feature = "_rt-async-std")]
|
||||
let connector = sqlx_rt::TlsConnector::from(builder);
|
||||
let connector = configure_tls_connector(
|
||||
accept_invalid_certs,
|
||||
accept_invalid_hostnames,
|
||||
root_cert_path,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let stream = match replace(self, MaybeTlsStream::Upgrading) {
|
||||
MaybeTlsStream::Raw(stream) => stream,
|
||||
@@ -75,12 +58,71 @@ where
|
||||
}
|
||||
};
|
||||
|
||||
#[cfg(feature = "_tls-rustls")]
|
||||
let host = webpki::DNSNameRef::try_from_ascii_str(host)?;
|
||||
|
||||
*self = MaybeTlsStream::Tls(connector.connect(host, stream).await?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "_tls-native-tls")]
|
||||
async fn configure_tls_connector(
|
||||
accept_invalid_certs: bool,
|
||||
accept_invalid_hostnames: bool,
|
||||
root_cert_path: Option<&Path>,
|
||||
) -> Result<sqlx_rt::TlsConnector, Error> {
|
||||
use sqlx_rt::native_tls::{Certificate, TlsConnector};
|
||||
|
||||
let mut builder = TlsConnector::builder();
|
||||
builder
|
||||
.danger_accept_invalid_certs(accept_invalid_certs)
|
||||
.danger_accept_invalid_hostnames(accept_invalid_hostnames);
|
||||
|
||||
if !accept_invalid_certs {
|
||||
if let Some(ca) = root_cert_path {
|
||||
let data = fs::read(ca).await?;
|
||||
let cert = Certificate::from_pem(&data)?;
|
||||
|
||||
builder.add_root_certificate(cert);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "_rt-async-std"))]
|
||||
let connector = builder.build()?.into();
|
||||
|
||||
#[cfg(feature = "_rt-async-std")]
|
||||
let connector = builder.into();
|
||||
|
||||
Ok(connector)
|
||||
}
|
||||
|
||||
#[cfg(feature = "_tls-rustls")]
|
||||
async fn configure_tls_connector(
|
||||
_accept_invalid_certs: bool,
|
||||
_accept_invalid_hostnames: bool,
|
||||
root_cert_path: Option<&Path>,
|
||||
) -> Result<sqlx_rt::TlsConnector, Error> {
|
||||
// FIXME: Support accept_invalid_certs / accept_invalid_hostnames
|
||||
|
||||
use rustls::ClientConfig;
|
||||
use std::io::Cursor;
|
||||
use std::sync::Arc;
|
||||
|
||||
let mut config = ClientConfig::new();
|
||||
|
||||
if let Some(ca) = root_cert_path {
|
||||
let data = fs::read(ca).await?;
|
||||
let mut cursor = Cursor::new(data);
|
||||
config.root_store.add_pem_file(&mut cursor).map_err(|_| {
|
||||
Error::Tls(format!("Invalid certificate file: {}", ca.display()).into())
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(Arc::new(config).into())
|
||||
}
|
||||
|
||||
impl<S> AsyncRead for MaybeTlsStream<S>
|
||||
where
|
||||
S: Unpin + AsyncWrite + AsyncRead,
|
||||
@@ -192,12 +234,15 @@ where
|
||||
match self {
|
||||
MaybeTlsStream::Raw(s) => s,
|
||||
|
||||
#[cfg(not(feature = "_rt-async-std"))]
|
||||
MaybeTlsStream::Tls(s) => s.get_ref().get_ref().get_ref(),
|
||||
#[cfg(feature = "_tls-rustls")]
|
||||
MaybeTlsStream::Tls(s) => s.get_ref().0,
|
||||
|
||||
#[cfg(feature = "_rt-async-std")]
|
||||
#[cfg(all(feature = "_rt-async-std", feature = "_tls-native-tls"))]
|
||||
MaybeTlsStream::Tls(s) => s.get_ref(),
|
||||
|
||||
#[cfg(all(not(feature = "_rt-async-std"), feature = "_tls-native-tls"))]
|
||||
MaybeTlsStream::Tls(s) => s.get_ref().get_ref().get_ref(),
|
||||
|
||||
MaybeTlsStream::Upgrading => panic!(io::Error::from(io::ErrorKind::ConnectionAborted)),
|
||||
}
|
||||
}
|
||||
@@ -211,12 +256,15 @@ where
|
||||
match self {
|
||||
MaybeTlsStream::Raw(s) => s,
|
||||
|
||||
#[cfg(not(feature = "_rt-async-std"))]
|
||||
MaybeTlsStream::Tls(s) => s.get_mut().get_mut().get_mut(),
|
||||
#[cfg(feature = "_tls-rustls")]
|
||||
MaybeTlsStream::Tls(s) => s.get_mut().0,
|
||||
|
||||
#[cfg(feature = "_rt-async-std")]
|
||||
#[cfg(all(feature = "_rt-async-std", feature = "_tls-native-tls"))]
|
||||
MaybeTlsStream::Tls(s) => s.get_mut(),
|
||||
|
||||
#[cfg(all(not(feature = "_rt-async-std"), feature = "_tls-native-tls"))]
|
||||
MaybeTlsStream::Tls(s) => s.get_mut().get_mut().get_mut(),
|
||||
|
||||
MaybeTlsStream::Upgrading => panic!(io::Error::from(io::ErrorKind::ConnectionAborted)),
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user