mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-12-29 21:00:54 +00:00
De-duplicate mysql & postgres TLS code
This commit is contained in:
parent
f28ab22748
commit
cd44b5eb43
@ -1,8 +1,3 @@
|
||||
use sqlx_rt::{
|
||||
fs,
|
||||
native_tls::{Certificate, TlsConnector},
|
||||
};
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::mysql::connection::MySqlStream;
|
||||
use crate::mysql::protocol::connect::SslRequest;
|
||||
@ -46,34 +41,20 @@ async fn upgrade(stream: &mut MySqlStream, options: &MySqlConnectOptions) -> Res
|
||||
|
||||
stream.flush().await?;
|
||||
|
||||
// FIXME: de-duplicate with postgres/connection/tls.rs
|
||||
|
||||
let accept_invalid_certs = !matches!(
|
||||
options.ssl_mode,
|
||||
MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity
|
||||
);
|
||||
let accept_invalid_host_names = !matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity);
|
||||
|
||||
let mut builder = TlsConnector::builder();
|
||||
builder
|
||||
.danger_accept_invalid_certs(accept_invalid_certs)
|
||||
.danger_accept_invalid_hostnames(!matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity));
|
||||
|
||||
if !accept_invalid_certs {
|
||||
if let Some(ca) = &options.ssl_ca {
|
||||
let data = fs::read(ca).await?;
|
||||
let cert = Certificate::from_pem(&data).map_err(Error::tls)?;
|
||||
|
||||
builder.add_root_certificate(cert);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "_rt-async-std"))]
|
||||
let connector = builder.build().map_err(Error::tls)?;
|
||||
|
||||
#[cfg(feature = "_rt-async-std")]
|
||||
let connector = builder;
|
||||
|
||||
stream.upgrade(&options.host, connector.into()).await?;
|
||||
stream
|
||||
.upgrade(
|
||||
&options.host,
|
||||
accept_invalid_certs,
|
||||
accept_invalid_host_names,
|
||||
options.ssl_ca.as_deref(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
@ -2,10 +2,15 @@
|
||||
|
||||
use std::io;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::path::Path;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use sqlx_rt::{AsyncRead, AsyncWrite, TlsConnector, TlsStream};
|
||||
use sqlx_rt::{
|
||||
fs,
|
||||
native_tls::{Certificate, TlsConnector},
|
||||
AsyncRead, AsyncWrite, TlsStream,
|
||||
};
|
||||
|
||||
use crate::error::Error;
|
||||
use std::mem::replace;
|
||||
@ -28,7 +33,33 @@ where
|
||||
matches!(self, Self::Tls(_))
|
||||
}
|
||||
|
||||
pub async fn upgrade(&mut self, host: &str, connector: TlsConnector) -> Result<(), Error> {
|
||||
pub async fn upgrade(
|
||||
&mut self,
|
||||
host: &str,
|
||||
accept_invalid_certs: bool,
|
||||
accept_invalid_hostnames: bool,
|
||||
root_cert_path: Option<&Path>,
|
||||
) -> Result<(), Error> {
|
||||
let mut builder = TlsConnector::builder();
|
||||
builder
|
||||
.danger_accept_invalid_certs(accept_invalid_certs)
|
||||
.danger_accept_invalid_hostnames(accept_invalid_hostnames);
|
||||
|
||||
if !accept_invalid_certs {
|
||||
if let Some(ca) = root_cert_path {
|
||||
let data = fs::read(ca).await?;
|
||||
let cert = Certificate::from_pem(&data).map_err(Error::tls)?;
|
||||
|
||||
builder.add_root_certificate(cert);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "_rt-async-std"))]
|
||||
let connector = builder.build().map_err(Error::tls)?;
|
||||
|
||||
#[cfg(feature = "_rt-async-std")]
|
||||
let connector = builder;
|
||||
|
||||
let stream = match replace(self, MaybeTlsStream::Upgrading) {
|
||||
MaybeTlsStream::Raw(stream) => stream,
|
||||
|
||||
@ -45,7 +76,7 @@ where
|
||||
};
|
||||
|
||||
*self = MaybeTlsStream::Tls(
|
||||
connector
|
||||
sqlx_rt::TlsConnector::from(connector)
|
||||
.connect(host, stream)
|
||||
.await
|
||||
.map_err(|err| Error::Tls(err.into()))?,
|
||||
|
||||
@ -1,8 +1,4 @@
|
||||
use bytes::Bytes;
|
||||
use sqlx_rt::{
|
||||
fs,
|
||||
native_tls::{Certificate, TlsConnector},
|
||||
};
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::postgres::connection::stream::PgStream;
|
||||
@ -63,34 +59,20 @@ async fn upgrade(stream: &mut PgStream, options: &PgConnectOptions) -> Result<bo
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME: de-duplicate with mysql/connection/tls.rs
|
||||
|
||||
let accept_invalid_certs = !matches!(
|
||||
options.ssl_mode,
|
||||
PgSslMode::VerifyCa | PgSslMode::VerifyFull
|
||||
);
|
||||
let accept_invalid_hostnames = !matches!(options.ssl_mode, PgSslMode::VerifyFull);
|
||||
|
||||
let mut builder = TlsConnector::builder();
|
||||
builder
|
||||
.danger_accept_invalid_certs(accept_invalid_certs)
|
||||
.danger_accept_invalid_hostnames(!matches!(options.ssl_mode, PgSslMode::VerifyFull));
|
||||
|
||||
if !accept_invalid_certs {
|
||||
if let Some(ca) = &options.ssl_root_cert {
|
||||
let data = fs::read(ca).await?;
|
||||
let cert = Certificate::from_pem(&data).map_err(Error::tls)?;
|
||||
|
||||
builder.add_root_certificate(cert);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "_rt-async-std"))]
|
||||
let connector = builder.build().map_err(Error::tls)?;
|
||||
|
||||
#[cfg(feature = "_rt-async-std")]
|
||||
let connector = builder;
|
||||
|
||||
stream.upgrade(&options.host, connector.into()).await?;
|
||||
stream
|
||||
.upgrade(
|
||||
&options.host,
|
||||
accept_invalid_certs,
|
||||
accept_invalid_hostnames,
|
||||
options.ssl_root_cert.as_deref(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user