diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index c2591fdb..3215cb32 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -22,7 +22,7 @@ all-type = ["bigdecimal", "json", "time", "chrono", "ipnetwork", "uuid"] # we need a feature which activates `num-bigint` as well because # `bigdecimal` uses types from it but does not reexport (tsk tsk) bigdecimal = ["bigdecimal_", "num-bigint"] -postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac", "futures-channel/sink", "futures-util/sink" ] +postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac", "futures-channel/sink", "futures-util/sink", "tokio/uds" ] json = ["serde", "serde_json"] mysql = [ "sha-1", "sha2", "generic-array", "num-bigint", "base64", "digest", "rand" ] sqlite = [ "libsqlite3-sys" ] diff --git a/sqlx-core/src/io/tls.rs b/sqlx-core/src/io/tls.rs index a81c4f7d..9e0e032e 100644 --- a/sqlx-core/src/io/tls.rs +++ b/sqlx-core/src/io/tls.rs @@ -13,6 +13,8 @@ pub struct MaybeTlsStream { enum Inner { NotTls(TcpStream), + #[cfg(all(feature = "postgres", unix))] + UnixStream(crate::runtime::UnixStream), #[cfg(feature = "tls")] Tls(async_native_tls::TlsStream), #[cfg(feature = "tls")] @@ -20,6 +22,13 @@ enum Inner { } impl MaybeTlsStream { + #[cfg(all(feature = "postgres", unix))] + pub async fn connect_uds>(p: S) -> crate::Result { + let conn = crate::runtime::UnixStream::connect(p.as_ref()).await?; + Ok(Self { + inner: Inner::UnixStream(conn), + }) + } pub async fn connect(host: &str, port: u16) -> crate::Result { let conn = TcpStream::connect((host, port)).await?; Ok(Self { @@ -31,6 +40,8 @@ impl MaybeTlsStream { pub fn is_tls(&self) -> bool { match self.inner { Inner::NotTls(_) => false, + #[cfg(all(feature = "postgres", unix))] + Inner::UnixStream(_) => false, #[cfg(feature = "tls")] Inner::Tls(_) => true, #[cfg(feature = "tls")] @@ -47,6 +58,10 @@ impl MaybeTlsStream { ) -> crate::Result<()> { let conn = match std::mem::replace(&mut self.inner, Upgrading) { NotTls(conn) => conn, + #[cfg(all(feature = "postgres", unix))] + UnixStream(_) => { + return Err(tls_err!("TLS is not supported with unix domain sockets").into()) + } Tls(_) => return Err(tls_err!("connection already upgraded").into()), Upgrading => return Err(tls_err!("connection already failed to upgrade").into()), }; @@ -59,6 +74,8 @@ impl MaybeTlsStream { pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { match self.inner { NotTls(ref conn) => conn.shutdown(how), + #[cfg(all(feature = "postgres", unix))] + UnixStream(ref conn) => conn.shutdown(how), #[cfg(feature = "tls")] Tls(ref conn) => conn.get_ref().shutdown(how), #[cfg(feature = "tls")] @@ -72,6 +89,8 @@ macro_rules! forward_pin ( ($self:ident.$method:ident($($arg:ident),*)) => ( match &mut $self.inner { NotTls(ref mut conn) => Pin::new(conn).$method($($arg),*), + #[cfg(all(feature = "postgres", unix))] + UnixStream(ref mut conn) => Pin::new(conn).$method($($arg),*), #[cfg(feature = "tls")] Tls(ref mut conn) => Pin::new(conn).$method($($arg),*), #[cfg(feature = "tls")] diff --git a/sqlx-core/src/postgres/stream.rs b/sqlx-core/src/postgres/stream.rs index 6beb5ee2..52ca2079 100644 --- a/sqlx-core/src/postgres/stream.rs +++ b/sqlx-core/src/postgres/stream.rs @@ -23,9 +23,15 @@ pub struct PgStream { impl PgStream { pub(super) async fn new(url: &Url) -> crate::Result { - let host = url.host().unwrap_or("localhost"); + let host = url.host(); let port = url.port(5432); - let stream = MaybeTlsStream::connect(host, port).await?; + #[cfg(unix)] + let stream = match host { + Some(host) => MaybeTlsStream::connect(host, port).await?, + None => MaybeTlsStream::connect_uds(format!("/var/run/postgresql/.s.PGSQL.{}", port)).await?, + }; + #[cfg(not(unix))] + let stream = MaybeTlsStream::connect(host.unwrap_or("localhost"), port).await?; Ok(Self { notifications: None, diff --git a/sqlx-core/src/runtime.rs b/sqlx-core/src/runtime.rs index e76becc8..49e9c940 100644 --- a/sqlx-core/src/runtime.rs +++ b/sqlx-core/src/runtime.rs @@ -17,6 +17,9 @@ pub(crate) use async_std::{ task::spawn, }; +#[cfg(all(feature = "runtime-async-std", feature = "postgres", unix))] +pub(crate) use async_std::os::unix::net::UnixStream; + #[cfg(feature = "runtime-tokio")] pub(crate) use tokio::{ fs, @@ -26,3 +29,6 @@ pub(crate) use tokio::{ time::delay_for as sleep, time::timeout, }; + +#[cfg(all(feature = "runtime-tokio", feature = "postgres", unix))] +pub(crate) use tokio::net::UnixStream;