diff --git a/sqlx-core/src/mysql/connection/tls.rs b/sqlx-core/src/mysql/connection/tls.rs index 2ed47a6e..2b2a5b80 100644 --- a/sqlx-core/src/mysql/connection/tls.rs +++ b/sqlx-core/src/mysql/connection/tls.rs @@ -1,8 +1,3 @@ -use sqlx_rt::{ - fs, - native_tls::{Certificate, TlsConnector}, -}; - use crate::error::Error; use crate::mysql::connection::MySqlStream; use crate::mysql::protocol::connect::SslRequest; @@ -46,34 +41,20 @@ async fn upgrade(stream: &mut MySqlStream, options: &MySqlConnectOptions) -> Res stream.flush().await?; - // FIXME: de-duplicate with postgres/connection/tls.rs - let accept_invalid_certs = !matches!( options.ssl_mode, MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity ); + let accept_invalid_host_names = !matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity); - let mut builder = TlsConnector::builder(); - builder - .danger_accept_invalid_certs(accept_invalid_certs) - .danger_accept_invalid_hostnames(!matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity)); - - if !accept_invalid_certs { - if let Some(ca) = &options.ssl_ca { - let data = fs::read(ca).await?; - let cert = Certificate::from_pem(&data).map_err(Error::tls)?; - - builder.add_root_certificate(cert); - } - } - - #[cfg(not(feature = "_rt-async-std"))] - let connector = builder.build().map_err(Error::tls)?; - - #[cfg(feature = "_rt-async-std")] - let connector = builder; - - stream.upgrade(&options.host, connector.into()).await?; + stream + .upgrade( + &options.host, + accept_invalid_certs, + accept_invalid_host_names, + options.ssl_ca.as_deref(), + ) + .await?; Ok(true) } diff --git a/sqlx-core/src/net/tls.rs b/sqlx-core/src/net/tls.rs index 89c1b91d..674babfa 100644 --- a/sqlx-core/src/net/tls.rs +++ b/sqlx-core/src/net/tls.rs @@ -2,10 +2,15 @@ use std::io; use std::ops::{Deref, DerefMut}; +use std::path::Path; use std::pin::Pin; use std::task::{Context, Poll}; -use sqlx_rt::{AsyncRead, AsyncWrite, TlsConnector, TlsStream}; +use sqlx_rt::{ + fs, + native_tls::{Certificate, TlsConnector}, + AsyncRead, AsyncWrite, TlsStream, +}; use crate::error::Error; use std::mem::replace; @@ -28,7 +33,33 @@ where matches!(self, Self::Tls(_)) } - pub async fn upgrade(&mut self, host: &str, connector: TlsConnector) -> Result<(), Error> { + pub async fn upgrade( + &mut self, + host: &str, + accept_invalid_certs: bool, + accept_invalid_hostnames: bool, + root_cert_path: Option<&Path>, + ) -> Result<(), Error> { + let mut builder = TlsConnector::builder(); + builder + .danger_accept_invalid_certs(accept_invalid_certs) + .danger_accept_invalid_hostnames(accept_invalid_hostnames); + + if !accept_invalid_certs { + if let Some(ca) = root_cert_path { + let data = fs::read(ca).await?; + let cert = Certificate::from_pem(&data).map_err(Error::tls)?; + + builder.add_root_certificate(cert); + } + } + + #[cfg(not(feature = "_rt-async-std"))] + let connector = builder.build().map_err(Error::tls)?; + + #[cfg(feature = "_rt-async-std")] + let connector = builder; + let stream = match replace(self, MaybeTlsStream::Upgrading) { MaybeTlsStream::Raw(stream) => stream, @@ -45,7 +76,7 @@ where }; *self = MaybeTlsStream::Tls( - connector + sqlx_rt::TlsConnector::from(connector) .connect(host, stream) .await .map_err(|err| Error::Tls(err.into()))?, diff --git a/sqlx-core/src/postgres/connection/tls.rs b/sqlx-core/src/postgres/connection/tls.rs index 03cbd6ae..283cc1b1 100644 --- a/sqlx-core/src/postgres/connection/tls.rs +++ b/sqlx-core/src/postgres/connection/tls.rs @@ -1,8 +1,4 @@ use bytes::Bytes; -use sqlx_rt::{ - fs, - native_tls::{Certificate, TlsConnector}, -}; use crate::error::Error; use crate::postgres::connection::stream::PgStream; @@ -63,34 +59,20 @@ async fn upgrade(stream: &mut PgStream, options: &PgConnectOptions) -> Result