diff --git a/sqlx-core/src/common/statement_cache.rs b/sqlx-core/src/common/statement_cache.rs index 9fb6fb150..e6c8223f7 100644 --- a/sqlx-core/src/common/statement_cache.rs +++ b/sqlx-core/src/common/statement_cache.rs @@ -27,7 +27,7 @@ impl StatementCache { pub fn insert(&mut self, k: &str, v: T) -> Option { let mut lru_item = None; - if self.inner.capacity() == self.len() && !self.inner.contains_key(k) { + if self.capacity() == self.len() && !self.contains_key(k) { lru_item = self.remove_lru(); } else if self.contains_key(k) { lru_item = self.inner.remove(k); @@ -49,7 +49,7 @@ impl StatementCache { } /// Clear all cached statements from the cache. - #[cfg(any(feature = "sqlite"))] + #[cfg(feature = "sqlite")] pub fn clear(&mut self) { self.inner.clear(); } diff --git a/sqlx-core/src/mysql/connection/mod.rs b/sqlx-core/src/mysql/connection/mod.rs index 0e8ef3873..5c46473fd 100644 --- a/sqlx-core/src/mysql/connection/mod.rs +++ b/sqlx-core/src/mysql/connection/mod.rs @@ -1,5 +1,4 @@ use std::fmt::{self, Debug, Formatter}; -use std::net::Shutdown; use std::sync::Arc; use futures_core::future::BoxFuture; @@ -57,7 +56,7 @@ impl Connection for MySqlConnection { fn close(mut self) -> BoxFuture<'static, Result<(), Error>> { Box::pin(async move { self.stream.send_packet(Quit).await?; - self.stream.shutdown(Shutdown::Both)?; + self.stream.shutdown()?; Ok(()) }) diff --git a/sqlx-core/src/mysql/connection/stream.rs b/sqlx-core/src/mysql/connection/stream.rs index 51d3d48ec..e9a9317f2 100644 --- a/sqlx-core/src/mysql/connection/stream.rs +++ b/sqlx-core/src/mysql/connection/stream.rs @@ -1,7 +1,6 @@ use std::ops::{Deref, DerefMut}; use bytes::{Buf, Bytes}; -use sqlx_rt::TcpStream; use crate::error::Error; use crate::io::{BufStream, Decode, Encode}; @@ -9,10 +8,10 @@ use crate::mysql::io::MySqlBufExt; use crate::mysql::protocol::response::{EofPacket, ErrPacket, OkPacket, Status}; use crate::mysql::protocol::{Capabilities, Packet}; use crate::mysql::{MySqlConnectOptions, MySqlDatabaseError}; -use crate::net::MaybeTlsStream; +use crate::net::{MaybeTlsStream, Socket}; pub struct MySqlStream { - stream: BufStream>, + stream: BufStream>, pub(super) capabilities: Capabilities, pub(crate) sequence_id: u8, pub(crate) busy: Busy, @@ -31,7 +30,10 @@ pub(crate) enum Busy { impl MySqlStream { pub(super) async fn connect(options: &MySqlConnectOptions) -> Result { - let stream = TcpStream::connect((&*options.host, options.port)).await?; + let socket = match options.socket { + Some(ref path) => Socket::connect_uds(path).await?, + None => Socket::connect(&options.host, options.port).await?, + }; let mut capabilities = Capabilities::PROTOCOL_41 | Capabilities::IGNORE_SPACE @@ -54,7 +56,7 @@ impl MySqlStream { busy: Busy::NotBusy, capabilities, sequence_id: 0, - stream: BufStream::new(MaybeTlsStream::Raw(stream)), + stream: BufStream::new(MaybeTlsStream::Raw(socket)), }) } @@ -178,7 +180,7 @@ impl MySqlStream { } impl Deref for MySqlStream { - type Target = BufStream>; + type Target = BufStream>; fn deref(&self) -> &Self::Target { &self.stream diff --git a/sqlx-core/src/mysql/options.rs b/sqlx-core/src/mysql/options.rs index 35eaad33a..ee25e6902 100644 --- a/sqlx-core/src/mysql/options.rs +++ b/sqlx-core/src/mysql/options.rs @@ -75,6 +75,7 @@ impl FromStr for MySqlSslMode { /// | `ssl-mode` | `PREFERRED` | Determines whether or with what priority a secure SSL TCP/IP connection will be negotiated. See [`MySqlSslMode`]. | /// | `ssl-ca` | `None` | Sets the name of a file containing a list of trusted SSL Certificate Authorities. | /// | `statement-cache-capacity` | `100` | The maximum number of prepared statements stored in the cache. Set to `0` to disable. | +/// | `socket` | `None` | Path to the unix domain socket, which will be used instead of TCP if set. | /// /// # Example /// @@ -106,6 +107,7 @@ impl FromStr for MySqlSslMode { pub struct MySqlConnectOptions { pub(crate) host: String, pub(crate) port: u16, + pub(crate) socket: Option, pub(crate) username: String, pub(crate) password: Option, pub(crate) database: Option, @@ -126,6 +128,7 @@ impl MySqlConnectOptions { Self { port: 3306, host: String::from("localhost"), + socket: None, username: String::from("root"), password: None, database: None, @@ -152,6 +155,15 @@ impl MySqlConnectOptions { self } + /// Pass a path to a Unix socket. This changes the connection stream from + /// TCP to UDS. + /// + /// By default set to `None`. + pub fn socket(mut self, path: impl AsRef) -> Self { + self.socket = Some(path.as_ref().to_path_buf()); + self + } + /// Sets the username to connect as. pub fn username(mut self, username: &str) -> Self { self.username = username.to_owned(); @@ -258,6 +270,10 @@ impl FromStr for MySqlConnectOptions { options = options.statement_cache_capacity(value.parse()?); } + "socket" => { + options = options.socket(&*value); + } + _ => {} } } diff --git a/sqlx-core/src/net/socket.rs b/sqlx-core/src/net/socket.rs index fbfc7cf1e..7850738fa 100644 --- a/sqlx-core/src/net/socket.rs +++ b/sqlx-core/src/net/socket.rs @@ -2,6 +2,7 @@ use std::io; use std::net::Shutdown; +use std::path::Path; use std::pin::Pin; use std::task::{Context, Poll}; @@ -26,14 +27,27 @@ impl Socket { if host.starts_with('/') { // if the host starts with a forward slash, assume that this is a request // to connect to a local socket - sqlx_rt::UnixStream::connect(format!("{}/.s.PGSQL.{}", host, port)) - .await - .map(Socket::Unix) + Self::connect_uds(&format!("{}/.s.PGSQL.{}", host, port)).await } else { TcpStream::connect((host, port)).await.map(Socket::Tcp) } } + #[cfg(unix)] + pub async fn connect_uds(path: impl AsRef) -> io::Result { + sqlx_rt::UnixStream::connect(path.as_ref()) + .await + .map(Socket::Unix) + } + + #[cfg(not(unix))] + pub async fn connect_uds(_: impl AsRef) -> io::Result { + Err(io::Error( + io::ErrorKind::Other, + "Unix domain sockets are not supported outside Unix platforms.", + )) + } + pub fn shutdown(&self) -> io::Result<()> { match self { Socket::Tcp(s) => s.shutdown(Shutdown::Both), diff --git a/sqlx-core/src/postgres/connection/stream.rs b/sqlx-core/src/postgres/connection/stream.rs index 52fc38358..ea35d1f68 100644 --- a/sqlx-core/src/postgres/connection/stream.rs +++ b/sqlx-core/src/postgres/connection/stream.rs @@ -31,9 +31,15 @@ pub struct PgStream { impl PgStream { pub(super) async fn connect(options: &PgConnectOptions) -> Result { - let inner = BufStream::new(MaybeTlsStream::Raw( - Socket::connect(&options.host, options.port).await?, - )); + let socket = match options.socket { + Some(ref path) => { + Socket::connect_uds(&format!("{}/.s.PGSQL.{}", path.display(), options.port)) + .await? + } + None => Socket::connect(&options.host, options.port).await?, + }; + + let inner = BufStream::new(MaybeTlsStream::Raw(socket)); Ok(Self { inner, diff --git a/sqlx-core/src/postgres/options.rs b/sqlx-core/src/postgres/options.rs index d1640fa95..bd11f13ae 100644 --- a/sqlx-core/src/postgres/options.rs +++ b/sqlx-core/src/postgres/options.rs @@ -76,7 +76,7 @@ impl FromStr for PgSslMode { /// | `sslmode` | `prefer` | Determines whether or with what priority a secure SSL TCP/IP connection will be negotiated. See [`PgSqlSslMode`]. | /// | `sslrootcert` | `None` | Sets the name of a file containing a list of trusted SSL Certificate Authorities. | /// | `statement-cache-capacity` | `100` | The maximum number of prepared statements stored in the cache. Set to `0` to disable. | -/// +/// | `host` | `None` | Path to the directory containing a PostgreSQL unix domain socket, which will be used instead of TCP if set. | /// /// The URI scheme designator can be either `postgresql://` or `postgres://`. /// Each of the URI parts is optional. @@ -121,6 +121,7 @@ impl FromStr for PgSslMode { pub struct PgConnectOptions { pub(crate) host: String, pub(crate) port: u16, + pub(crate) socket: Option, pub(crate) username: String, pub(crate) password: Option, pub(crate) database: Option, @@ -166,6 +167,7 @@ impl PgConnectOptions { PgConnectOptions { port, host, + socket: None, username: var("PGUSER").ok().unwrap_or_else(whoami::username), password: var("PGPASSWORD").ok(), database: var("PGDATABASE").ok(), @@ -215,6 +217,25 @@ impl PgConnectOptions { self } + /// Sets a custom path to a directory containing a unix domain socket, + /// switching the connection method from TCP to the corresponding socket. + /// + /// By default set to `None`. + #[cfg(unix)] + pub fn socket(mut self, path: impl AsRef) -> Self { + self.socket = Some(path.as_ref().to_path_buf()); + self + } + + /// Sets a custom path to a directory containing a unix domain socket, + /// switching the connection method from TCP to the corresponding socket. + /// + /// By default set to `None`. + #[cfg(not(unix))] + pub fn socket(mut self, _: impl AsRef) -> Self { + self + } + /// Sets the username to connect as. /// /// Defaults to be the same as the operating system name of @@ -373,6 +394,10 @@ impl FromStr for PgConnectOptions { options = options.statement_cache_capacity(value.parse()?); } + "host" => { + options = options.socket(&*value); + } + _ => {} } }