mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-10-02 15:25:32 +00:00
postgres: Add unix domain socket support
This commit is contained in:
parent
49f15713d6
commit
5628658d3f
@ -22,7 +22,7 @@ all-type = ["bigdecimal", "json", "time", "chrono", "ipnetwork", "uuid"]
|
||||
# we need a feature which activates `num-bigint` as well because
|
||||
# `bigdecimal` uses types from it but does not reexport (tsk tsk)
|
||||
bigdecimal = ["bigdecimal_", "num-bigint"]
|
||||
postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac", "futures-channel/sink", "futures-util/sink" ]
|
||||
postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac", "futures-channel/sink", "futures-util/sink", "tokio/uds" ]
|
||||
json = ["serde", "serde_json"]
|
||||
mysql = [ "sha-1", "sha2", "generic-array", "num-bigint", "base64", "digest", "rand" ]
|
||||
sqlite = [ "libsqlite3-sys" ]
|
||||
|
@ -13,6 +13,8 @@ pub struct MaybeTlsStream {
|
||||
|
||||
enum Inner {
|
||||
NotTls(TcpStream),
|
||||
#[cfg(all(feature = "postgres", unix))]
|
||||
UnixStream(crate::runtime::UnixStream),
|
||||
#[cfg(feature = "tls")]
|
||||
Tls(async_native_tls::TlsStream<TcpStream>),
|
||||
#[cfg(feature = "tls")]
|
||||
@ -20,6 +22,13 @@ enum Inner {
|
||||
}
|
||||
|
||||
impl MaybeTlsStream {
|
||||
#[cfg(all(feature = "postgres", unix))]
|
||||
pub async fn connect_uds<S: AsRef<std::ffi::OsStr>>(p: S) -> crate::Result<Self> {
|
||||
let conn = crate::runtime::UnixStream::connect(p.as_ref()).await?;
|
||||
Ok(Self {
|
||||
inner: Inner::UnixStream(conn),
|
||||
})
|
||||
}
|
||||
pub async fn connect(host: &str, port: u16) -> crate::Result<Self> {
|
||||
let conn = TcpStream::connect((host, port)).await?;
|
||||
Ok(Self {
|
||||
@ -31,6 +40,8 @@ impl MaybeTlsStream {
|
||||
pub fn is_tls(&self) -> bool {
|
||||
match self.inner {
|
||||
Inner::NotTls(_) => false,
|
||||
#[cfg(all(feature = "postgres", unix))]
|
||||
Inner::UnixStream(_) => false,
|
||||
#[cfg(feature = "tls")]
|
||||
Inner::Tls(_) => true,
|
||||
#[cfg(feature = "tls")]
|
||||
@ -47,6 +58,10 @@ impl MaybeTlsStream {
|
||||
) -> crate::Result<()> {
|
||||
let conn = match std::mem::replace(&mut self.inner, Upgrading) {
|
||||
NotTls(conn) => conn,
|
||||
#[cfg(all(feature = "postgres", unix))]
|
||||
UnixStream(_) => {
|
||||
return Err(tls_err!("TLS is not supported with unix domain sockets").into())
|
||||
}
|
||||
Tls(_) => return Err(tls_err!("connection already upgraded").into()),
|
||||
Upgrading => return Err(tls_err!("connection already failed to upgrade").into()),
|
||||
};
|
||||
@ -59,6 +74,8 @@ impl MaybeTlsStream {
|
||||
pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
|
||||
match self.inner {
|
||||
NotTls(ref conn) => conn.shutdown(how),
|
||||
#[cfg(all(feature = "postgres", unix))]
|
||||
UnixStream(ref conn) => conn.shutdown(how),
|
||||
#[cfg(feature = "tls")]
|
||||
Tls(ref conn) => conn.get_ref().shutdown(how),
|
||||
#[cfg(feature = "tls")]
|
||||
@ -72,6 +89,8 @@ macro_rules! forward_pin (
|
||||
($self:ident.$method:ident($($arg:ident),*)) => (
|
||||
match &mut $self.inner {
|
||||
NotTls(ref mut conn) => Pin::new(conn).$method($($arg),*),
|
||||
#[cfg(all(feature = "postgres", unix))]
|
||||
UnixStream(ref mut conn) => Pin::new(conn).$method($($arg),*),
|
||||
#[cfg(feature = "tls")]
|
||||
Tls(ref mut conn) => Pin::new(conn).$method($($arg),*),
|
||||
#[cfg(feature = "tls")]
|
||||
|
@ -23,9 +23,15 @@ pub struct PgStream {
|
||||
|
||||
impl PgStream {
|
||||
pub(super) async fn new(url: &Url) -> crate::Result<Self> {
|
||||
let host = url.host().unwrap_or("localhost");
|
||||
let host = url.host();
|
||||
let port = url.port(5432);
|
||||
let stream = MaybeTlsStream::connect(host, port).await?;
|
||||
#[cfg(unix)]
|
||||
let stream = match host {
|
||||
Some(host) => MaybeTlsStream::connect(host, port).await?,
|
||||
None => MaybeTlsStream::connect_uds(format!("/var/run/postgresql/.s.PGSQL.{}", port)).await?,
|
||||
};
|
||||
#[cfg(not(unix))]
|
||||
let stream = MaybeTlsStream::connect(host.unwrap_or("localhost"), port).await?;
|
||||
|
||||
Ok(Self {
|
||||
notifications: None,
|
||||
|
@ -17,6 +17,9 @@ pub(crate) use async_std::{
|
||||
task::spawn,
|
||||
};
|
||||
|
||||
#[cfg(all(feature = "runtime-async-std", feature = "postgres", unix))]
|
||||
pub(crate) use async_std::os::unix::net::UnixStream;
|
||||
|
||||
#[cfg(feature = "runtime-tokio")]
|
||||
pub(crate) use tokio::{
|
||||
fs,
|
||||
@ -26,3 +29,6 @@ pub(crate) use tokio::{
|
||||
time::delay_for as sleep,
|
||||
time::timeout,
|
||||
};
|
||||
|
||||
#[cfg(all(feature = "runtime-tokio", feature = "postgres", unix))]
|
||||
pub(crate) use tokio::net::UnixStream;
|
||||
|
Loading…
x
Reference in New Issue
Block a user