Merge branch 'alex-berger-feature/inline-certificates'

This commit is contained in:
Ryan Leckey 2021-01-20 22:08:35 -08:00
commit 05c1a8899a
No known key found for this signature in database
GPG Key ID: F8AA68C235AB08C9
7 changed files with 96 additions and 25 deletions

View File

@ -52,7 +52,7 @@ async fn upgrade(stream: &mut MySqlStream, options: &MySqlConnectOptions) -> Res
&options.host,
accept_invalid_certs,
accept_invalid_host_names,
options.ssl_ca.as_deref(),
options.ssl_ca.as_ref(),
)
.await?;

View File

@ -4,7 +4,7 @@ mod connect;
mod parse;
mod ssl_mode;
use crate::connection::LogSettings;
use crate::{connection::LogSettings, net::CertificateInput};
pub use ssl_mode::MySqlSslMode;
/// Options and flags which can be used to configure a MySQL connection.
@ -60,7 +60,7 @@ pub struct MySqlConnectOptions {
pub(crate) password: Option<String>,
pub(crate) database: Option<String>,
pub(crate) ssl_mode: MySqlSslMode,
pub(crate) ssl_ca: Option<PathBuf>,
pub(crate) ssl_ca: Option<CertificateInput>,
pub(crate) statement_cache_capacity: usize,
pub(crate) charset: String,
pub(crate) collation: Option<String>,
@ -165,7 +165,22 @@ impl MySqlConnectOptions {
/// .ssl_ca("path/to/ca.crt");
/// ```
pub fn ssl_ca(mut self, file_name: impl AsRef<Path>) -> Self {
self.ssl_ca = Some(file_name.as_ref().to_owned());
self.ssl_ca = Some(CertificateInput::File(file_name.as_ref().to_owned()));
self
}
/// Sets PEM encoded list of trusted SSL Certificate Authorities.
///
/// # Example
///
/// ```rust
/// # use sqlx_core::mysql::{MySqlSslMode, MySqlConnectOptions};
/// let options = MySqlConnectOptions::new()
/// .ssl_mode(MySqlSslMode::VerifyCa)
/// .ssl_ca_from_pem(vec![]);
/// ```
pub fn ssl_ca_from_pem(mut self, pem_certificate: Vec<u8>) -> Self {
self.ssl_ca = Some(CertificateInput::Inline(pem_certificate));
self
}

View File

@ -2,7 +2,7 @@ mod socket;
mod tls;
pub use socket::Socket;
pub use tls::MaybeTlsStream;
pub use tls::{CertificateInput, MaybeTlsStream};
#[cfg(feature = "_rt-async-std")]
type PollReadBuf<'a> = [u8];

View File

@ -2,7 +2,7 @@
use std::io;
use std::ops::{Deref, DerefMut};
use std::path::Path;
use std::path::PathBuf;
use std::pin::Pin;
use std::task::{Context, Poll};
@ -11,6 +11,48 @@ use sqlx_rt::{AsyncRead, AsyncWrite, TlsStream};
use crate::error::Error;
use std::mem::replace;
/// X.509 Certificate input, either a file path or a PEM encoded inline certificate(s).
#[derive(Clone, Debug)]
pub enum CertificateInput {
/// PEM encoded certificate(s)
Inline(Vec<u8>),
/// Path to a file containing PEM encoded certificate(s)
File(PathBuf),
}
impl From<String> for CertificateInput {
fn from(value: String) -> Self {
let trimmed = value.trim();
// Some heuristics according to https://tools.ietf.org/html/rfc7468
if trimmed.starts_with("-----BEGIN CERTIFICATE-----")
&& trimmed.contains("-----END CERTIFICATE-----")
{
CertificateInput::Inline(value.as_bytes().to_vec())
} else {
CertificateInput::File(PathBuf::from(value))
}
}
}
impl CertificateInput {
async fn data(&self) -> Result<Vec<u8>, std::io::Error> {
use sqlx_rt::fs;
match self {
CertificateInput::Inline(v) => Ok(v.clone()),
CertificateInput::File(path) => fs::read(path).await,
}
}
}
impl std::fmt::Display for CertificateInput {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CertificateInput::Inline(v) => write!(f, "{}", String::from_utf8_lossy(v.as_slice())),
CertificateInput::File(path) => write!(f, "file: {}", path.display()),
}
}
}
#[cfg(feature = "_tls-rustls")]
mod rustls;
@ -37,7 +79,7 @@ where
host: &str,
accept_invalid_certs: bool,
accept_invalid_hostnames: bool,
root_cert_path: Option<&Path>,
root_cert_path: Option<&CertificateInput>,
) -> Result<(), Error> {
let connector = configure_tls_connector(
accept_invalid_certs,
@ -74,12 +116,9 @@ where
async fn configure_tls_connector(
accept_invalid_certs: bool,
accept_invalid_hostnames: bool,
root_cert_path: Option<&Path>,
root_cert_path: Option<&CertificateInput>,
) -> Result<sqlx_rt::TlsConnector, Error> {
use sqlx_rt::{
fs,
native_tls::{Certificate, TlsConnector},
};
use sqlx_rt::native_tls::{Certificate, TlsConnector};
let mut builder = TlsConnector::builder();
builder
@ -88,7 +127,7 @@ async fn configure_tls_connector(
if !accept_invalid_certs {
if let Some(ca) = root_cert_path {
let data = fs::read(ca).await?;
let data = ca.data().await?;
let cert = Certificate::from_pem(&data)?;
builder.add_root_certificate(cert);

View File

@ -1,10 +1,10 @@
use crate::net::CertificateInput;
use rustls::{
Certificate, ClientConfig, RootCertStore, ServerCertVerified, ServerCertVerifier, TLSError,
WebPKIVerifier,
};
use sqlx_rt::fs;
use std::io::Cursor;
use std::sync::Arc;
use std::{io::Cursor, path::Path};
use webpki::DNSNameRef;
use crate::error::Error;
@ -12,7 +12,7 @@ use crate::error::Error;
pub async fn configure_tls_connector(
accept_invalid_certs: bool,
accept_invalid_hostnames: bool,
root_cert_path: Option<&Path>,
root_cert_path: Option<&CertificateInput>,
) -> Result<sqlx_rt::TlsConnector, Error> {
let mut config = ClientConfig::new();
@ -26,11 +26,12 @@ pub async fn configure_tls_connector(
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
if let Some(ca) = root_cert_path {
let data = fs::read(ca).await?;
let data = ca.data().await?;
let mut cursor = Cursor::new(data);
config.root_store.add_pem_file(&mut cursor).map_err(|_| {
Error::Tls(format!("Invalid certificate file: {}", ca.display()).into())
})?;
config
.root_store
.add_pem_file(&mut cursor)
.map_err(|_| Error::Tls(format!("Invalid certificate {}", ca).into()))?;
}
if accept_invalid_hostnames {

View File

@ -70,7 +70,7 @@ async fn upgrade(stream: &mut PgStream, options: &PgConnectOptions) -> Result<bo
&options.host,
accept_invalid_certs,
accept_invalid_hostnames,
options.ssl_root_cert.as_deref(),
options.ssl_root_cert.as_ref(),
)
.await?;

View File

@ -4,7 +4,7 @@ use std::path::{Path, PathBuf};
mod connect;
mod parse;
mod ssl_mode;
use crate::connection::LogSettings;
use crate::{connection::LogSettings, net::CertificateInput};
pub use ssl_mode::PgSslMode;
/// Options and flags which can be used to configure a PostgreSQL connection.
@ -80,7 +80,7 @@ pub struct PgConnectOptions {
pub(crate) password: Option<String>,
pub(crate) database: Option<String>,
pub(crate) ssl_mode: PgSslMode,
pub(crate) ssl_root_cert: Option<PathBuf>,
pub(crate) ssl_root_cert: Option<CertificateInput>,
pub(crate) statement_cache_capacity: usize,
pub(crate) application_name: Option<String>,
pub(crate) log_settings: LogSettings,
@ -128,7 +128,7 @@ impl PgConnectOptions {
username: var("PGUSER").ok().unwrap_or_else(whoami::username),
password: var("PGPASSWORD").ok(),
database: var("PGDATABASE").ok(),
ssl_root_cert: var("PGSSLROOTCERT").ok().map(PathBuf::from),
ssl_root_cert: var("PGSSLROOTCERT").ok().map(CertificateInput::from),
ssl_mode: var("PGSSLMODE")
.ok()
.and_then(|v| v.parse().ok())
@ -265,7 +265,23 @@ impl PgConnectOptions {
/// .ssl_root_cert("./ca-certificate.crt");
/// ```
pub fn ssl_root_cert(mut self, cert: impl AsRef<Path>) -> Self {
self.ssl_root_cert = Some(cert.as_ref().to_path_buf());
self.ssl_root_cert = Some(CertificateInput::File(cert.as_ref().to_path_buf()));
self
}
/// Sets PEM encoded trusted SSL Certificate Authorities (CA).
///
/// # Example
///
/// ```rust
/// # use sqlx_core::postgres::{PgSslMode, PgConnectOptions};
/// let options = PgConnectOptions::new()
/// // Providing a CA certificate with less than VerifyCa is pointless
/// .ssl_mode(PgSslMode::VerifyCa)
/// .ssl_root_cert_from_pem(vec![]);
/// ```
pub fn ssl_root_cert_from_pem(mut self, pem_certificate: Vec<u8>) -> Self {
self.ssl_root_cert = Some(CertificateInput::Inline(pem_certificate));
self
}