De-duplicate mysql & postgres TLS code

This commit is contained in:
Jonas Platte 2020-10-20 17:46:40 +02:00 committed by Ryan Leckey
parent f28ab22748
commit cd44b5eb43
3 changed files with 52 additions and 58 deletions

View File

@ -1,8 +1,3 @@
use sqlx_rt::{
fs,
native_tls::{Certificate, TlsConnector},
};
use crate::error::Error;
use crate::mysql::connection::MySqlStream;
use crate::mysql::protocol::connect::SslRequest;
@ -46,34 +41,20 @@ async fn upgrade(stream: &mut MySqlStream, options: &MySqlConnectOptions) -> Res
stream.flush().await?;
// FIXME: de-duplicate with postgres/connection/tls.rs
let accept_invalid_certs = !matches!(
options.ssl_mode,
MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity
);
let accept_invalid_host_names = !matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity);
let mut builder = TlsConnector::builder();
builder
.danger_accept_invalid_certs(accept_invalid_certs)
.danger_accept_invalid_hostnames(!matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity));
if !accept_invalid_certs {
if let Some(ca) = &options.ssl_ca {
let data = fs::read(ca).await?;
let cert = Certificate::from_pem(&data).map_err(Error::tls)?;
builder.add_root_certificate(cert);
}
}
#[cfg(not(feature = "_rt-async-std"))]
let connector = builder.build().map_err(Error::tls)?;
#[cfg(feature = "_rt-async-std")]
let connector = builder;
stream.upgrade(&options.host, connector.into()).await?;
stream
.upgrade(
&options.host,
accept_invalid_certs,
accept_invalid_host_names,
options.ssl_ca.as_deref(),
)
.await?;
Ok(true)
}

View File

@ -2,10 +2,15 @@
use std::io;
use std::ops::{Deref, DerefMut};
use std::path::Path;
use std::pin::Pin;
use std::task::{Context, Poll};
use sqlx_rt::{AsyncRead, AsyncWrite, TlsConnector, TlsStream};
use sqlx_rt::{
fs,
native_tls::{Certificate, TlsConnector},
AsyncRead, AsyncWrite, TlsStream,
};
use crate::error::Error;
use std::mem::replace;
@ -28,7 +33,33 @@ where
matches!(self, Self::Tls(_))
}
pub async fn upgrade(&mut self, host: &str, connector: TlsConnector) -> Result<(), Error> {
pub async fn upgrade(
&mut self,
host: &str,
accept_invalid_certs: bool,
accept_invalid_hostnames: bool,
root_cert_path: Option<&Path>,
) -> Result<(), Error> {
let mut builder = TlsConnector::builder();
builder
.danger_accept_invalid_certs(accept_invalid_certs)
.danger_accept_invalid_hostnames(accept_invalid_hostnames);
if !accept_invalid_certs {
if let Some(ca) = root_cert_path {
let data = fs::read(ca).await?;
let cert = Certificate::from_pem(&data).map_err(Error::tls)?;
builder.add_root_certificate(cert);
}
}
#[cfg(not(feature = "_rt-async-std"))]
let connector = builder.build().map_err(Error::tls)?;
#[cfg(feature = "_rt-async-std")]
let connector = builder;
let stream = match replace(self, MaybeTlsStream::Upgrading) {
MaybeTlsStream::Raw(stream) => stream,
@ -45,7 +76,7 @@ where
};
*self = MaybeTlsStream::Tls(
connector
sqlx_rt::TlsConnector::from(connector)
.connect(host, stream)
.await
.map_err(|err| Error::Tls(err.into()))?,

View File

@ -1,8 +1,4 @@
use bytes::Bytes;
use sqlx_rt::{
fs,
native_tls::{Certificate, TlsConnector},
};
use crate::error::Error;
use crate::postgres::connection::stream::PgStream;
@ -63,34 +59,20 @@ async fn upgrade(stream: &mut PgStream, options: &PgConnectOptions) -> Result<bo
}
}
// FIXME: de-duplicate with mysql/connection/tls.rs
let accept_invalid_certs = !matches!(
options.ssl_mode,
PgSslMode::VerifyCa | PgSslMode::VerifyFull
);
let accept_invalid_hostnames = !matches!(options.ssl_mode, PgSslMode::VerifyFull);
let mut builder = TlsConnector::builder();
builder
.danger_accept_invalid_certs(accept_invalid_certs)
.danger_accept_invalid_hostnames(!matches!(options.ssl_mode, PgSslMode::VerifyFull));
if !accept_invalid_certs {
if let Some(ca) = &options.ssl_root_cert {
let data = fs::read(ca).await?;
let cert = Certificate::from_pem(&data).map_err(Error::tls)?;
builder.add_root_certificate(cert);
}
}
#[cfg(not(feature = "_rt-async-std"))]
let connector = builder.build().map_err(Error::tls)?;
#[cfg(feature = "_rt-async-std")]
let connector = builder;
stream.upgrade(&options.host, connector.into()).await?;
stream
.upgrade(
&options.host,
accept_invalid_certs,
accept_invalid_hostnames,
options.ssl_root_cert.as_deref(),
)
.await?;
Ok(true)
}