mirror of
https://github.com/launchbadge/sqlx.git
synced 2026-04-18 20:55:45 +00:00
implement TLS support for Postgres
This commit is contained in:
3
.github/workflows/postgres.yml
vendored
3
.github/workflows/postgres.yml
vendored
@@ -49,6 +49,9 @@ jobs:
|
||||
|
||||
# -----------------------------------------------------
|
||||
|
||||
# Check that we build with TLS support (TODO: we need a postgres image with SSL certs to test)
|
||||
- run: cargo build -p sqlx-core --no-default-features 'postgres macros uuid chrono tls'
|
||||
|
||||
- run: cargo test -p sqlx --no-default-features --features 'postgres macros uuid chrono'
|
||||
env:
|
||||
DATABASE_URL: postgres://postgres:postgres@localhost:${{ job.services.postgres.ports[5432] }}/postgres
|
||||
|
||||
@@ -30,6 +30,7 @@ all-features = true
|
||||
[features]
|
||||
default = [ "macros" ]
|
||||
macros = [ "sqlx-macros", "proc-macro-hack" ]
|
||||
tls = ["sqlx-core/tls"]
|
||||
|
||||
# database
|
||||
postgres = [ "sqlx-core/postgres", "sqlx-macros/postgres" ]
|
||||
@@ -48,6 +49,7 @@ hex = "0.4.0"
|
||||
[dev-dependencies]
|
||||
anyhow = "1.0.26"
|
||||
futures = "0.3.1"
|
||||
env_logger = "0.7"
|
||||
async-std = { version = "1.4.0", features = [ "attributes" ] }
|
||||
dotenv = "0.15.0"
|
||||
|
||||
|
||||
@@ -20,8 +20,10 @@ default = []
|
||||
unstable = []
|
||||
postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac" ]
|
||||
mysql = [ "sha-1", "sha2", "generic-array", "num-bigint", "base64", "digest", "rand" ]
|
||||
tls = ["async-native-tls"]
|
||||
|
||||
[dependencies]
|
||||
async-native-tls = { version = "0.3", optional = true }
|
||||
async-std = "1.4.0"
|
||||
async-stream = { version = "0.2.0", default-features = false }
|
||||
base64 = { version = "0.11.0", default-features = false, optional = true, features = [ "std" ] }
|
||||
|
||||
@@ -44,6 +44,9 @@ pub enum Error {
|
||||
/// [Pool::close] was called while we were waiting in [Pool::acquire].
|
||||
PoolClosed,
|
||||
|
||||
/// An error occurred during a TLS upgrade.
|
||||
TlsUpgrade(Box<dyn StdError + Send + Sync>),
|
||||
|
||||
Decode(DecodeError),
|
||||
|
||||
// TODO: Remove and replace with `#[non_exhaustive]` when possible
|
||||
@@ -62,6 +65,8 @@ impl StdError for Error {
|
||||
|
||||
Error::Decode(DecodeError::Other(error)) => Some(&**error),
|
||||
|
||||
Error::TlsUpgrade(error) => Some(&**error),
|
||||
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@@ -100,6 +105,8 @@ impl Display for Error {
|
||||
|
||||
Error::PoolClosed => f.write_str("attempted to acquire a connection on a closed pool"),
|
||||
|
||||
Error::TlsUpgrade(ref err) => write!(f, "error during TLS upgrade: {}", err),
|
||||
|
||||
Error::__Nonexhaustive => unreachable!(),
|
||||
}
|
||||
}
|
||||
@@ -140,6 +147,21 @@ impl From<ProtocolError<'_>> for Error {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
impl From<async_native_tls::Error> for Error {
|
||||
#[inline]
|
||||
fn from(err: async_native_tls::Error) -> Self {
|
||||
Error::TlsUpgrade(err.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TlsError<'_>> for Error {
|
||||
#[inline]
|
||||
fn from(err: TlsError<'_>) -> Self {
|
||||
Error::TlsUpgrade(err.args.to_string().into())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<T> for Error
|
||||
where
|
||||
T: 'static + DatabaseError,
|
||||
@@ -189,6 +211,15 @@ macro_rules! protocol_err (
|
||||
}
|
||||
);
|
||||
|
||||
pub(crate) struct TlsError<'a> {
|
||||
pub args: fmt::Arguments<'a>,
|
||||
}
|
||||
|
||||
#[allow(unused_macros)]
|
||||
macro_rules! tls_err {
|
||||
($($args:tt)*) => { crate::error::TlsError { args: format_args!($($args)*)} };
|
||||
}
|
||||
|
||||
#[allow(unused_macros)]
|
||||
macro_rules! impl_fmt_error {
|
||||
($err:ty) => {
|
||||
@@ -212,3 +243,4 @@ macro_rules! impl_fmt_error {
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -51,6 +51,12 @@ where
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn clear_bufs(&mut self) {
|
||||
self.rbuf_rindex = 0;
|
||||
self.rbuf_windex = 0;
|
||||
self.wbuf.clear();
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn consume(&mut self, cnt: usize) {
|
||||
self.rbuf_rindex += cnt;
|
||||
|
||||
@@ -5,11 +5,14 @@ mod buf;
|
||||
mod buf_mut;
|
||||
mod byte_str;
|
||||
|
||||
mod tls;
|
||||
|
||||
pub use self::{
|
||||
buf::{Buf, ToBuf},
|
||||
buf_mut::BufMut,
|
||||
buf_stream::BufStream,
|
||||
byte_str::ByteStr,
|
||||
tls::MaybeTlsStream
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
109
sqlx-core/src/io/tls.rs
Normal file
109
sqlx-core/src/io/tls.rs
Normal file
@@ -0,0 +1,109 @@
|
||||
use std::io::{IoSlice, IoSliceMut};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use async_std::io::{self, Read, Write};
|
||||
use async_std::net::{Shutdown, TcpStream};
|
||||
|
||||
use crate::url::Url;
|
||||
|
||||
use self::Inner::*;
|
||||
|
||||
pub struct MaybeTlsStream {
|
||||
inner: Inner,
|
||||
}
|
||||
|
||||
enum Inner {
|
||||
NotTls(TcpStream),
|
||||
#[cfg(feature = "tls")]
|
||||
Tls(async_native_tls::TlsStream<TcpStream>),
|
||||
#[cfg(feature = "tls")]
|
||||
Upgrading,
|
||||
}
|
||||
|
||||
impl MaybeTlsStream {
|
||||
pub async fn connect(url: &Url, default_port: u16) -> crate::Result<Self> {
|
||||
let conn = TcpStream::connect((url.host(), url.port(default_port))).await?;
|
||||
Ok(Self { inner: Inner::NotTls(conn) })
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
pub async fn upgrade(&mut self, url: &Url, connector: async_native_tls::TlsConnector) -> crate::Result<()> {
|
||||
let conn = match std::mem::replace(&mut self.inner, Upgrading) {
|
||||
NotTls(conn) => conn,
|
||||
Tls(_) => return Err(tls_err!("connection already upgraded").into()),
|
||||
Upgrading => return Err(tls_err!("connection already failed to upgrade").into()),
|
||||
};
|
||||
|
||||
self.inner = Tls(connector.connect(url.host(), conn).await?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
|
||||
match self.inner {
|
||||
NotTls(ref conn) => conn.shutdown(how),
|
||||
#[cfg(feature = "tls")]
|
||||
Tls(ref conn) => conn.get_ref().shutdown(how),
|
||||
#[cfg(feature = "tls")]
|
||||
// connection already closed
|
||||
Upgrading => Ok(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! forward_pin (
|
||||
($self:ident.$method:ident($($arg:ident),*)) => (
|
||||
match &mut $self.inner {
|
||||
NotTls(ref mut conn) => Pin::new(conn).$method($($arg),*),
|
||||
#[cfg(feature = "tls")]
|
||||
Tls(ref mut conn) => Pin::new(conn).$method($($arg),*),
|
||||
#[cfg(feature = "tls")]
|
||||
Upgrading => Err(io::Error::new(io::ErrorKind::Other, "connection broken; TLS upgrade failed")).into(),
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
impl Read for MaybeTlsStream {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
forward_pin!(self.poll_read(cx, buf))
|
||||
}
|
||||
|
||||
fn poll_read_vectored(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context,
|
||||
bufs: &mut [IoSliceMut],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
forward_pin!(self.poll_read_vectored(cx, bufs))
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for MaybeTlsStream {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
forward_pin!(self.poll_write(cx, buf))
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
forward_pin!(self.poll_flush(cx))
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
forward_pin!(self.poll_close(cx))
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context,
|
||||
bufs: &[IoSlice],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
forward_pin!(self.poll_write_vectored(cx, bufs))
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
use std::convert::TryInto;
|
||||
use std::path::Path;
|
||||
|
||||
use async_std::net::{Shutdown, TcpStream};
|
||||
use async_std::fs;
|
||||
use async_std::net::Shutdown;
|
||||
use byteorder::NetworkEndian;
|
||||
use futures_core::future::BoxFuture;
|
||||
|
||||
@@ -11,12 +13,15 @@ use crate::postgres::protocol::{
|
||||
self, hi, Authentication, Decode, Encode, Message, SaslInitialResponse, SaslResponse,
|
||||
StatementId,
|
||||
};
|
||||
use crate::io::{Buf, BufStream, MaybeTlsStream};
|
||||
use crate::postgres::PgError;
|
||||
use crate::postgres::protocol::{self, Decode, Encode, Message, StatementId};
|
||||
use crate::url::Url;
|
||||
use crate::Result;
|
||||
use hmac::{Hmac, Mac};
|
||||
use rand::Rng;
|
||||
use sha2::{Digest, Sha256};
|
||||
use async_native_tls::Certificate;
|
||||
|
||||
/// An asynchronous connection to a [Postgres] database.
|
||||
///
|
||||
@@ -24,7 +29,7 @@ use sha2::{Digest, Sha256};
|
||||
/// string, as documented at
|
||||
/// <https://www.postgresql.org/docs/12/libpq-connect.html#LIBPQ-CONNSTRING>
|
||||
pub struct PgConnection {
|
||||
pub(super) stream: BufStream<TcpStream>,
|
||||
pub(super) stream: BufStream<MaybeTlsStream>,
|
||||
|
||||
// Map of query to statement id
|
||||
pub(super) statement_cache: StatementCache<StatementId>,
|
||||
@@ -43,8 +48,43 @@ pub struct PgConnection {
|
||||
}
|
||||
|
||||
impl PgConnection {
|
||||
#[cfg(feature = "tls")]
|
||||
async fn try_ssl(&mut self, url: &Url, invalid_certs: bool, invalid_hostnames: bool) -> crate::Result<bool> {
|
||||
use async_native_tls::{TlsConnector, Certificate};
|
||||
use std::env;
|
||||
|
||||
protocol::SslRequest::encode(self.stream.buffer_mut());
|
||||
|
||||
self.stream.flush().await?;
|
||||
|
||||
match self.stream.peek(1).await? {
|
||||
Some(b"N") => return Ok(false),
|
||||
Some(b"S") => (),
|
||||
Some(other) => return Err(tls_err!("unexpected single-byte response: 0x{:02X}", other[0]).into()),
|
||||
None => return Err(tls_err!("server unexpectedly closed connection").into())
|
||||
}
|
||||
|
||||
let mut connector = TlsConnector::new()
|
||||
.danger_accept_invalid_certs(invalid_certs)
|
||||
.danger_accept_invalid_hostnames(invalid_hostnames);
|
||||
|
||||
if !invalid_certs {
|
||||
match read_root_certificate(&url).await {
|
||||
Ok(cert) => {
|
||||
connector = connector.add_root_certificate(cert);
|
||||
}
|
||||
Err(e) => log::warn!("failed to read Postgres root certificate: {}", e)
|
||||
}
|
||||
}
|
||||
|
||||
self.stream.clear_bufs();
|
||||
self.stream.stream.upgrade(url, connector).await?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
// https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3
|
||||
async fn startup(&mut self, url: Url) -> Result<()> {
|
||||
async fn startup(&mut self, url: &Url) -> Result<()> {
|
||||
// Defaults to postgres@.../postgres
|
||||
let username = url.username().unwrap_or("postgres");
|
||||
let database = url.database().unwrap_or("postgres");
|
||||
@@ -83,7 +123,7 @@ impl PgConnection {
|
||||
protocol::PasswordMessage::ClearText(
|
||||
url.password().unwrap_or_default(),
|
||||
)
|
||||
.encode(self.stream.buffer_mut());
|
||||
.encode(self.stream.buffer_mut());
|
||||
|
||||
self.stream.flush().await?;
|
||||
}
|
||||
@@ -94,7 +134,7 @@ impl PgConnection {
|
||||
user: username,
|
||||
salt,
|
||||
}
|
||||
.encode(self.stream.buffer_mut());
|
||||
.encode(self.stream.buffer_mut());
|
||||
|
||||
self.stream.flush().await?;
|
||||
}
|
||||
@@ -137,7 +177,7 @@ impl PgConnection {
|
||||
"requires unimplemented authentication method: {:?}",
|
||||
auth
|
||||
)
|
||||
.into());
|
||||
.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -240,7 +280,8 @@ impl PgConnection {
|
||||
impl PgConnection {
|
||||
pub(super) async fn open(url: Result<Url>) -> Result<Self> {
|
||||
let url = url?;
|
||||
let stream = TcpStream::connect((url.host(), url.port(5432))).await?;
|
||||
|
||||
let stream = MaybeTlsStream::connect(&url, 5432).await?;
|
||||
let mut self_ = Self {
|
||||
stream: BufStream::new(stream),
|
||||
process_id: 0,
|
||||
@@ -251,7 +292,36 @@ impl PgConnection {
|
||||
ready: true,
|
||||
};
|
||||
|
||||
self_.startup(url).await?;
|
||||
let ssl_mode = url.get_param("sslmode").unwrap_or("prefer".into());
|
||||
|
||||
match &*ssl_mode {
|
||||
// TODO: on "allow" retry with TLS if startup fails
|
||||
"disable" | "allow" => (),
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
"prefer" => { self_.try_ssl(&url, true, true).await?; },
|
||||
|
||||
#[cfg(not(feature = "tls"))]
|
||||
"prefer" => log::info!("compiled without TLS, skipping upgrade"),
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
"require" | "verify-ca" | "verify-full" => if !self_.try_ssl(
|
||||
&url,
|
||||
ssl_mode == "require", // false for both verify-ca and verify-full
|
||||
ssl_mode != "verify-full" // false for only verify-full
|
||||
).await? {
|
||||
return Err(tls_err!("Postgres server does not support TLS").into())
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "tls"))]
|
||||
"require" | "verify-ca" | "verify-full" => return Err(
|
||||
tls_err!("sslmode {:?} unsupported; SQLx was compiled without `tls` feature",
|
||||
ssl_mode).into()
|
||||
),
|
||||
_ => return Err(tls_err!("unknown `sslmode` value: {:?}", ssl_mode).into()),
|
||||
}
|
||||
|
||||
self_.startup(&url).await?;
|
||||
|
||||
Ok(self_)
|
||||
}
|
||||
@@ -259,9 +329,9 @@ impl PgConnection {
|
||||
|
||||
impl Connection for PgConnection {
|
||||
fn open<T>(url: T) -> BoxFuture<'static, Result<Self>>
|
||||
where
|
||||
T: TryInto<Url, Error = crate::Error>,
|
||||
Self: Sized,
|
||||
where
|
||||
T: TryInto<Url, Error=crate::Error>,
|
||||
Self: Sized,
|
||||
{
|
||||
Box::pin(PgConnection::open(url.try_into()))
|
||||
}
|
||||
@@ -271,6 +341,26 @@ impl Connection for PgConnection {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
async fn read_root_certificate(url: &Url) -> crate::Result<async_native_tls::Certificate> {
|
||||
use std::env;
|
||||
|
||||
let root_cert_path = if let Some(path) = url.get_param("sslrootcert") {
|
||||
path.into()
|
||||
} else if let Ok(cert_path) = env::var("PGSSLROOTCERT"){
|
||||
cert_path
|
||||
} else if cfg!(windows) {
|
||||
let appdata = env::var("APPDATA").map_err(|_| tls_err!("APPDATA not set"))?;
|
||||
format!("{}\\postgresql\\root.crt", appdata)
|
||||
} else {
|
||||
let home = env::var("HOME").map_err(|_| tls_err!("HOME not set"))?;
|
||||
format!("{}/.postgresql/root.crt", home)
|
||||
};
|
||||
|
||||
let root_cert = async_std::fs::read(root_cert_path).await?;
|
||||
Ok(async_native_tls::Certificate::from_pem(&root_cert)?)
|
||||
}
|
||||
|
||||
static GS2_HEADER: &'static str = "n,,";
|
||||
static CHANNEL_ATTR: &'static str = "c";
|
||||
static USERNAME_ATTR: &'static str = "n";
|
||||
@@ -354,7 +444,7 @@ async fn sasl_auth<T: AsRef<str>>(conn: &mut PgConnection, username: T, password
|
||||
);
|
||||
|
||||
// AuthMessage := client-first-message-bare + "," + server-first-message + "," + client-final-message-without-proof
|
||||
let auth_message = format!("{client_first_message_bare},{server_first_message},{client_final_message_wo_proof}",
|
||||
let auth_message = format!("{client_first_message_bare},{server_first_message},{client_final_message_wo_proof}",
|
||||
client_first_message_bare = client_first_message_bare,
|
||||
server_first_message = server_first_message,
|
||||
client_final_message_wo_proof = client_final_message_wo_proof);
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
// the size of this module to exactly what is necessary.
|
||||
#![allow(unused)]
|
||||
|
||||
// REQUESTS
|
||||
mod bind;
|
||||
mod cancel_request;
|
||||
mod close;
|
||||
@@ -16,6 +17,7 @@ mod parse;
|
||||
mod password_message;
|
||||
mod query;
|
||||
mod sasl;
|
||||
mod ssl_request;
|
||||
mod startup_message;
|
||||
mod statement;
|
||||
mod sync;
|
||||
@@ -32,11 +34,13 @@ pub use parse::Parse;
|
||||
pub use password_message::PasswordMessage;
|
||||
pub use query::Query;
|
||||
pub use sasl::{hi, SaslInitialResponse, SaslResponse};
|
||||
pub use ssl_request::SslRequest;
|
||||
pub use startup_message::StartupMessage;
|
||||
pub use statement::StatementId;
|
||||
pub use sync::Sync;
|
||||
pub use terminate::Terminate;
|
||||
|
||||
// RESPONSES
|
||||
mod authentication;
|
||||
mod backend_key_data;
|
||||
mod command_complete;
|
||||
|
||||
25
sqlx-core/src/postgres/protocol/ssl_request.rs
Normal file
25
sqlx-core/src/postgres/protocol/ssl_request.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
use crate::io::{Buf, BufMut};
|
||||
use byteorder::NetworkEndian;
|
||||
|
||||
pub struct SslRequest;
|
||||
|
||||
impl SslRequest {
|
||||
pub fn encode(buf: &mut Vec<u8>) {
|
||||
// packet length: 8 bytes including self
|
||||
buf.put_u32::<NetworkEndian>(8);
|
||||
// 1234 in high 16 bits, 5679 in low 16
|
||||
buf.put_u32::<NetworkEndian>(
|
||||
(1234 << 16) | 5679,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ssl_request() {
|
||||
use crate::io::Buf;
|
||||
|
||||
let mut buf = Vec::new();
|
||||
SslRequest::encode(&mut buf);
|
||||
|
||||
assert_eq!((&buf[..]).get_u32::<NetworkEndian>().unwrap(), 80877103);
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
use std::borrow::Cow;
|
||||
|
||||
pub struct Url(url::Url);
|
||||
|
||||
@@ -64,4 +65,14 @@ impl Url {
|
||||
Some(database)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_param(&self, key: &str) -> Option<Cow<str>> {
|
||||
self.0.query_pairs().find_map(|(key_, val)| {
|
||||
if key == key_ {
|
||||
Some(val)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,5 +66,7 @@ async fn it_remains_stable_issue_30() -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
async fn connect() -> anyhow::Result<PgConnection> {
|
||||
let _ = dotenv::dotenv();
|
||||
let _ = env_logger::try_init();
|
||||
Ok(PgConnection::open(dotenv::var("DATABASE_URL")?).await?)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user