#![allow(dead_code)] use std::io; use std::ops::{Deref, DerefMut}; use std::path::Path; use std::pin::Pin; use std::task::{Context, Poll}; use sqlx_rt::{AsyncRead, AsyncWrite, TlsStream}; use crate::error::Error; use std::mem::replace; #[cfg(feature = "_tls-rustls")] mod rustls; pub enum MaybeTlsStream where S: AsyncRead + AsyncWrite + Unpin, { Raw(S), Tls(TlsStream), Upgrading, } impl MaybeTlsStream where S: AsyncRead + AsyncWrite + Unpin, { #[inline] pub fn is_tls(&self) -> bool { matches!(self, Self::Tls(_)) } pub async fn upgrade( &mut self, host: &str, accept_invalid_certs: bool, accept_invalid_hostnames: bool, root_cert_path: Option<&Path>, ) -> Result<(), Error> { let connector = configure_tls_connector( accept_invalid_certs, accept_invalid_hostnames, root_cert_path, ) .await?; let stream = match replace(self, MaybeTlsStream::Upgrading) { MaybeTlsStream::Raw(stream) => stream, MaybeTlsStream::Tls(_) => { // ignore upgrade, we are already a TLS connection return Ok(()); } MaybeTlsStream::Upgrading => { // we previously failed to upgrade and now hold no connection // this should only happen from an internal misuse of this method return Err(Error::Io(io::ErrorKind::ConnectionAborted.into())); } }; #[cfg(feature = "_tls-rustls")] let host = webpki::DNSNameRef::try_from_ascii_str(host)?; *self = MaybeTlsStream::Tls(connector.connect(host, stream).await?); Ok(()) } } #[cfg(feature = "_tls-native-tls")] async fn configure_tls_connector( accept_invalid_certs: bool, accept_invalid_hostnames: bool, root_cert_path: Option<&Path>, ) -> Result { use sqlx_rt::{ fs, native_tls::{Certificate, TlsConnector}, }; let mut builder = TlsConnector::builder(); builder .danger_accept_invalid_certs(accept_invalid_certs) .danger_accept_invalid_hostnames(accept_invalid_hostnames); if !accept_invalid_certs { if let Some(ca) = root_cert_path { let data = fs::read(ca).await?; let cert = Certificate::from_pem(&data)?; builder.add_root_certificate(cert); } } #[cfg(not(feature = "_rt-async-std"))] let connector = builder.build()?.into(); #[cfg(feature = "_rt-async-std")] let connector = builder.into(); Ok(connector) } #[cfg(feature = "_tls-rustls")] use self::rustls::configure_tls_connector; impl AsyncRead for MaybeTlsStream where S: Unpin + AsyncWrite + AsyncRead, { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { match &mut *self { MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf), MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf), MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())), } } #[cfg(any(feature = "_rt-actix", feature = "_rt-tokio"))] fn poll_read_buf( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut B, ) -> Poll> where Self: Sized, B: bytes::BufMut, { match &mut *self { MaybeTlsStream::Raw(s) => Pin::new(s).poll_read_buf(cx, buf), MaybeTlsStream::Tls(s) => Pin::new(s).poll_read_buf(cx, buf), MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())), } } } impl AsyncWrite for MaybeTlsStream where S: Unpin + AsyncWrite + AsyncRead, { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { match &mut *self { MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf), MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf), MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())), } } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match &mut *self { MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx), MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx), MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())), } } #[cfg(any(feature = "_rt-actix", feature = "_rt-tokio"))] fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match &mut *self { MaybeTlsStream::Raw(s) => Pin::new(s).poll_shutdown(cx), MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx), MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())), } } #[cfg(feature = "_rt-async-std")] fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match &mut *self { MaybeTlsStream::Raw(s) => Pin::new(s).poll_close(cx), MaybeTlsStream::Tls(s) => Pin::new(s).poll_close(cx), MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())), } } #[cfg(any(feature = "_rt-actix", feature = "_rt-tokio"))] fn poll_write_buf( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut B, ) -> Poll> where Self: Sized, B: bytes::Buf, { match &mut *self { MaybeTlsStream::Raw(s) => Pin::new(s).poll_write_buf(cx, buf), MaybeTlsStream::Tls(s) => Pin::new(s).poll_write_buf(cx, buf), MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())), } } } impl Deref for MaybeTlsStream where S: Unpin + AsyncWrite + AsyncRead, { type Target = S; fn deref(&self) -> &Self::Target { match self { MaybeTlsStream::Raw(s) => s, #[cfg(feature = "_tls-rustls")] MaybeTlsStream::Tls(s) => s.get_ref().0, #[cfg(all(feature = "_rt-async-std", feature = "_tls-native-tls"))] MaybeTlsStream::Tls(s) => s.get_ref(), #[cfg(all(not(feature = "_rt-async-std"), feature = "_tls-native-tls"))] MaybeTlsStream::Tls(s) => s.get_ref().get_ref().get_ref(), MaybeTlsStream::Upgrading => panic!(io::Error::from(io::ErrorKind::ConnectionAborted)), } } } impl DerefMut for MaybeTlsStream where S: Unpin + AsyncWrite + AsyncRead, { fn deref_mut(&mut self) -> &mut Self::Target { match self { MaybeTlsStream::Raw(s) => s, #[cfg(feature = "_tls-rustls")] MaybeTlsStream::Tls(s) => s.get_mut().0, #[cfg(all(feature = "_rt-async-std", feature = "_tls-native-tls"))] MaybeTlsStream::Tls(s) => s.get_mut(), #[cfg(all(not(feature = "_rt-async-std"), feature = "_tls-native-tls"))] MaybeTlsStream::Tls(s) => s.get_mut().get_mut().get_mut(), MaybeTlsStream::Upgrading => panic!(io::Error::from(io::ErrorKind::ConnectionAborted)), } } }