mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-10-03 15:55:45 +00:00
256 lines
7.4 KiB
Rust
256 lines
7.4 KiB
Rust
#![allow(dead_code)]
|
|
|
|
use std::io;
|
|
use std::ops::{Deref, DerefMut};
|
|
use std::path::Path;
|
|
use std::pin::Pin;
|
|
use std::task::{Context, Poll};
|
|
|
|
use sqlx_rt::{AsyncRead, AsyncWrite, TlsStream};
|
|
|
|
use crate::error::Error;
|
|
use std::mem::replace;
|
|
|
|
#[cfg(feature = "_tls-rustls")]
|
|
mod rustls;
|
|
|
|
pub enum MaybeTlsStream<S>
|
|
where
|
|
S: AsyncRead + AsyncWrite + Unpin,
|
|
{
|
|
Raw(S),
|
|
Tls(TlsStream<S>),
|
|
Upgrading,
|
|
}
|
|
|
|
impl<S> MaybeTlsStream<S>
|
|
where
|
|
S: AsyncRead + AsyncWrite + Unpin,
|
|
{
|
|
#[inline]
|
|
pub fn is_tls(&self) -> bool {
|
|
matches!(self, Self::Tls(_))
|
|
}
|
|
|
|
pub async fn upgrade(
|
|
&mut self,
|
|
host: &str,
|
|
accept_invalid_certs: bool,
|
|
accept_invalid_hostnames: bool,
|
|
root_cert_path: Option<&Path>,
|
|
) -> Result<(), Error> {
|
|
let connector = configure_tls_connector(
|
|
accept_invalid_certs,
|
|
accept_invalid_hostnames,
|
|
root_cert_path,
|
|
)
|
|
.await?;
|
|
|
|
let stream = match replace(self, MaybeTlsStream::Upgrading) {
|
|
MaybeTlsStream::Raw(stream) => stream,
|
|
|
|
MaybeTlsStream::Tls(_) => {
|
|
// ignore upgrade, we are already a TLS connection
|
|
return Ok(());
|
|
}
|
|
|
|
MaybeTlsStream::Upgrading => {
|
|
// we previously failed to upgrade and now hold no connection
|
|
// this should only happen from an internal misuse of this method
|
|
return Err(Error::Io(io::ErrorKind::ConnectionAborted.into()));
|
|
}
|
|
};
|
|
|
|
#[cfg(feature = "_tls-rustls")]
|
|
let host = webpki::DNSNameRef::try_from_ascii_str(host)?;
|
|
|
|
*self = MaybeTlsStream::Tls(connector.connect(host, stream).await?);
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
/// Builds the native-tls connector used by `MaybeTlsStream::upgrade`.
///
/// * `accept_invalid_certs` / `accept_invalid_hostnames` - forwarded to the builder's
///   `danger_accept_invalid_*` switches.
/// * `root_cert_path` - optional PEM file added as a root certificate; it is only read
///   when `accept_invalid_certs` is `false`.
///
/// # Errors
///
/// Fails if the certificate file cannot be read or parsed, or if building the
/// connector fails.
#[cfg(feature = "_tls-native-tls")]
async fn configure_tls_connector(
    accept_invalid_certs: bool,
    accept_invalid_hostnames: bool,
    root_cert_path: Option<&Path>,
) -> Result<sqlx_rt::TlsConnector, Error> {
    use sqlx_rt::fs;
    use sqlx_rt::native_tls::{Certificate, TlsConnector};

    let mut builder = TlsConnector::builder();
    builder.danger_accept_invalid_certs(accept_invalid_certs);
    builder.danger_accept_invalid_hostnames(accept_invalid_hostnames);

    // The custom root certificate is skipped entirely when invalid certs are accepted.
    match root_cert_path {
        Some(ca) if !accept_invalid_certs => {
            let pem = fs::read(ca).await?;
            let root = Certificate::from_pem(&pem)?;

            builder.add_root_certificate(root);
        }
        _ => {}
    }

    // The two runtimes convert into `sqlx_rt::TlsConnector` from different types:
    // the built connector in one case, the builder itself in the other.
    #[cfg(not(feature = "_rt-async-std"))]
    let connector = builder.build()?.into();

    #[cfg(feature = "_rt-async-std")]
    let connector = builder.into();

    Ok(connector)
}
|
|
|
|
#[cfg(feature = "_tls-rustls")]
|
|
use self::rustls::configure_tls_connector;
|
|
|
|
impl<S> AsyncRead for MaybeTlsStream<S>
|
|
where
|
|
S: Unpin + AsyncWrite + AsyncRead,
|
|
{
|
|
fn poll_read(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &mut [u8],
|
|
) -> Poll<io::Result<usize>> {
|
|
match &mut *self {
|
|
MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf),
|
|
MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
|
|
|
|
MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
|
|
}
|
|
}
|
|
|
|
#[cfg(any(feature = "_rt-actix", feature = "_rt-tokio"))]
|
|
fn poll_read_buf<B>(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &mut B,
|
|
) -> Poll<io::Result<usize>>
|
|
where
|
|
Self: Sized,
|
|
B: bytes::BufMut,
|
|
{
|
|
match &mut *self {
|
|
MaybeTlsStream::Raw(s) => Pin::new(s).poll_read_buf(cx, buf),
|
|
MaybeTlsStream::Tls(s) => Pin::new(s).poll_read_buf(cx, buf),
|
|
|
|
MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<S> AsyncWrite for MaybeTlsStream<S>
|
|
where
|
|
S: Unpin + AsyncWrite + AsyncRead,
|
|
{
|
|
fn poll_write(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &[u8],
|
|
) -> Poll<io::Result<usize>> {
|
|
match &mut *self {
|
|
MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf),
|
|
MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
|
|
|
|
MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
|
|
}
|
|
}
|
|
|
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
match &mut *self {
|
|
MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx),
|
|
MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
|
|
|
|
MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
|
|
}
|
|
}
|
|
|
|
#[cfg(any(feature = "_rt-actix", feature = "_rt-tokio"))]
|
|
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
match &mut *self {
|
|
MaybeTlsStream::Raw(s) => Pin::new(s).poll_shutdown(cx),
|
|
MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
|
|
|
|
MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
|
|
}
|
|
}
|
|
|
|
#[cfg(feature = "_rt-async-std")]
|
|
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
match &mut *self {
|
|
MaybeTlsStream::Raw(s) => Pin::new(s).poll_close(cx),
|
|
MaybeTlsStream::Tls(s) => Pin::new(s).poll_close(cx),
|
|
|
|
MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
|
|
}
|
|
}
|
|
|
|
#[cfg(any(feature = "_rt-actix", feature = "_rt-tokio"))]
|
|
fn poll_write_buf<B>(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &mut B,
|
|
) -> Poll<io::Result<usize>>
|
|
where
|
|
Self: Sized,
|
|
B: bytes::Buf,
|
|
{
|
|
match &mut *self {
|
|
MaybeTlsStream::Raw(s) => Pin::new(s).poll_write_buf(cx, buf),
|
|
MaybeTlsStream::Tls(s) => Pin::new(s).poll_write_buf(cx, buf),
|
|
|
|
MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<S> Deref for MaybeTlsStream<S>
|
|
where
|
|
S: Unpin + AsyncWrite + AsyncRead,
|
|
{
|
|
type Target = S;
|
|
|
|
fn deref(&self) -> &Self::Target {
|
|
match self {
|
|
MaybeTlsStream::Raw(s) => s,
|
|
|
|
#[cfg(feature = "_tls-rustls")]
|
|
MaybeTlsStream::Tls(s) => s.get_ref().0,
|
|
|
|
#[cfg(all(feature = "_rt-async-std", feature = "_tls-native-tls"))]
|
|
MaybeTlsStream::Tls(s) => s.get_ref(),
|
|
|
|
#[cfg(all(not(feature = "_rt-async-std"), feature = "_tls-native-tls"))]
|
|
MaybeTlsStream::Tls(s) => s.get_ref().get_ref().get_ref(),
|
|
|
|
MaybeTlsStream::Upgrading => panic!(io::Error::from(io::ErrorKind::ConnectionAborted)),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<S> DerefMut for MaybeTlsStream<S>
|
|
where
|
|
S: Unpin + AsyncWrite + AsyncRead,
|
|
{
|
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
|
match self {
|
|
MaybeTlsStream::Raw(s) => s,
|
|
|
|
#[cfg(feature = "_tls-rustls")]
|
|
MaybeTlsStream::Tls(s) => s.get_mut().0,
|
|
|
|
#[cfg(all(feature = "_rt-async-std", feature = "_tls-native-tls"))]
|
|
MaybeTlsStream::Tls(s) => s.get_mut(),
|
|
|
|
#[cfg(all(not(feature = "_rt-async-std"), feature = "_tls-native-tls"))]
|
|
MaybeTlsStream::Tls(s) => s.get_mut().get_mut().get_mut(),
|
|
|
|
MaybeTlsStream::Upgrading => panic!(io::Error::from(io::ErrorKind::ConnectionAborted)),
|
|
}
|
|
}
|
|
}
|