Move pg-specific socket options to its options.

Makes tcp connections explicit.
This commit is contained in:
Julius de Bruijn 2020-06-26 10:54:52 +02:00 committed by Ryan Leckey
parent 71ebeb9cc3
commit 2115d02cb0
4 changed files with 21 additions and 20 deletions

View File

@ -32,7 +32,7 @@ impl MySqlStream {
pub(super) async fn connect(options: &MySqlConnectOptions) -> Result<Self, Error> {
let socket = match options.socket {
Some(ref path) => Socket::connect_uds(path).await?,
None => Socket::connect(&options.host, options.port).await?,
None => Socket::connect_tcp(&options.host, options.port).await?,
};
let mut capabilities = Capabilities::PROTOCOL_41

View File

@ -17,22 +17,10 @@ pub enum Socket {
}
impl Socket {
#[cfg(not(unix))]
pub async fn connect(host: &str, port: u16) -> io::Result<Self> {
pub async fn connect_tcp(host: &str, port: u16) -> io::Result<Self> {
TcpStream::connect((host, port)).await.map(Socket::Tcp)
}
#[cfg(unix)]
pub async fn connect(host: &str, port: u16) -> io::Result<Self> {
if host.starts_with('/') {
// if the host starts with a forward slash, assume that this is a request
// to connect to a local socket
Self::connect_uds(&format!("{}/.s.PGSQL.{}", host, port)).await
} else {
TcpStream::connect((host, port)).await.map(Socket::Tcp)
}
}
#[cfg(unix)]
pub async fn connect_uds(path: impl AsRef<Path>) -> io::Result<Self> {
sqlx_rt::UnixStream::connect(path.as_ref())

View File

@ -31,12 +31,9 @@ pub struct PgStream {
impl PgStream {
pub(super) async fn connect(options: &PgConnectOptions) -> Result<Self, Error> {
let socket = match options.socket {
Some(ref path) => {
Socket::connect_uds(&format!("{}/.s.PGSQL.{}", path.display(), options.port))
.await?
}
None => Socket::connect(&options.host, options.port).await?,
let socket = match options.fetch_socket() {
Some(ref path) => Socket::connect_uds(path).await?,
None => Socket::connect_tcp(&options.host, options.port).await?,
};
let inner = BufStream::new(MaybeTlsStream::Raw(socket));

View File

@ -320,6 +320,22 @@ impl PgConnectOptions {
self.statement_cache_capacity = capacity;
self
}
/// We try using a socket if hostname starts with `/` or if socket parameter
/// is specified.
pub(crate) fn fetch_socket(&self) -> Option<String> {
match self.socket {
Some(ref socket) => {
let full_path = format!("{}/.s.PGSQL.{}", socket.display(), self.port);
Some(full_path)
}
None if self.host.starts_with('/') => {
let full_path = format!("{}/.s.PGSQL.{}", self.host, self.port);
Some(full_path)
}
_ => None,
}
}
}
fn default_host(port: u16) -> String {