feat(postgres): add PgConnectOptions

This commit is contained in:
Ryan Leckey
2021-03-07 10:28:51 -08:00
parent baa63d33e1
commit 5a5ce82946
12 changed files with 441 additions and 23 deletions

View File

@@ -30,6 +30,7 @@ atoi = "0.4.0"
sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core" }
futures-util = { version = "0.3.8", optional = true }
log = "0.4.11"
either = "1.6.1"
bytestring = "1.0.0"
url = "2.2.0"
percent-encoding = "2.1.0"

View File

@@ -30,7 +30,7 @@ mod column;
mod database;
// mod error;
// mod io;
// mod options;
mod options;
mod output;
mod protocol;
mod query_result;
@@ -49,7 +49,7 @@ pub use column::PgColumn;
// pub use connection::PgConnection;
pub use database::Postgres;
// pub use error::PgDatabaseError;
// pub use options::PgConnectOptions;
pub use options::PgConnectOptions;
pub use output::PgOutput;
pub use query_result::PgQueryResult;
pub use raw_value::{PgRawValue, PgRawValueFormat};

View File

@@ -0,0 +1,53 @@
use std::fmt::{self, Debug, Formatter};
use std::path::PathBuf;
use either::Either;
use sqlx_core::ConnectOptions;
mod builder;
mod default;
mod getters;
mod parse;
/// Options which can be used to configure how a Postgres connection is opened.
///
/// A value of `PgConnectOptions` can be parsed from a connection URI, as
/// described by [libpq](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING).
///
/// ```text
/// postgresql://[user[:password]@][host][:port][/database][?param1=value1&...]
/// ```
///
#[allow(clippy::module_name_repetitions)]
#[derive(Clone)]
pub struct PgConnectOptions {
pub(crate) address: Either<(String, u16), PathBuf>,
username: Option<String>,
password: Option<String>,
database: Option<String>,
application_name: Option<String>,
}
impl Debug for PgConnectOptions {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("PgConnectOptions")
.field(
"address",
&self
.address
.as_ref()
.map_left(|(host, port)| format!("{}:{}", host, port))
.map_right(|socket| socket.display()),
)
.field("username", &self.username)
.field("password", &self.password)
.field("database", &self.database)
.field("application_name", &self.application_name)
.finish()
}
}
impl ConnectOptions for PgConnectOptions {}
#[cfg(feature = "blocking")]
impl sqlx_core::blocking::ConnectOptions for PgConnectOptions {}

View File

@@ -0,0 +1,97 @@
use std::mem;
use std::path::{Path, PathBuf};
use either::Either;
impl super::PgConnectOptions {
/// Sets the hostname of the database server.
///
/// If the hostname begins with a slash (`/`), it is interpreted as the absolute path
/// to a Unix domain socket file instead of a hostname of a server.
///
/// Defaults to either the `PGHOSTADDR` or `PGHOST` environment variable, falling back
/// to `localhost` if neither are present.
///
pub fn host(&mut self, host: impl AsRef<str>) -> &mut Self {
let host = host.as_ref();
self.address = if host.starts_with('/') {
Either::Right(PathBuf::from(&*host))
} else {
Either::Left((host.into(), self.get_port()))
};
self
}
/// Sets the path of the Unix domain socket to connect to.
///
/// Overrides [`host`](#method.host).
///
/// Defaults to, and overrides a default `host`, if one of the files is present in
/// the local filesystem:
///
/// - `/var/run/postgresql/.s.PGSQL.{port}`
/// - `/private/tmp/.s.PGSQL.{port}`
/// - `/tmp/.s.PGSQL.{port}`
///
pub fn socket(&mut self, socket: impl AsRef<Path>) -> &mut Self {
self.address = Either::Right(socket.as_ref().to_owned());
self
}
/// Sets the TCP port number of the database server.
///
/// Defaults to the `PGPORT` environment variable, falling back to `5432`
/// if not present.
///
pub fn port(&mut self, port: u16) -> &mut Self {
self.address = match self.address {
Either::Right(_) => Either::Left(("localhost".to_owned(), port)),
Either::Left((ref mut host, _)) => Either::Left((mem::take(host), port)),
};
self
}
/// Sets the user to be used for authentication.
///
/// Defaults to the `PGUSER` environment variable, if present.
///
pub fn username(&mut self, username: impl AsRef<str>) -> &mut Self {
self.username = Some(username.as_ref().to_owned());
self
}
/// Sets the password to be used for authentication.
///
/// Defaults to the `PGPASSWORD` environment variable, if present.
///
pub fn password(&mut self, password: impl AsRef<str>) -> &mut Self {
self.password = Some(password.as_ref().to_owned());
self
}
/// Sets the database for the connection.
///
/// Defaults to the `PGDATABASE` environment variable, falling back to
/// the name of the user, if not present.
///
pub fn database(&mut self, database: impl AsRef<str>) -> &mut Self {
self.database = Some(database.as_ref().to_owned());
self
}
/// Sets the application name for the connection.
///
/// The name will be displayed in the `pg_stat_activity` view and
/// included in CSV log entries. Only printable ASCII characters may be
/// used in the `application_name` value.
///
/// Defaults to the `PGAPPNAME` environment variable, if present.
///
pub fn application_name(&mut self, name: impl AsRef<str>) -> &mut Self {
self.application_name = Some(name.as_ref().to_owned());
self
}
}

View File

@@ -0,0 +1,57 @@
use std::env::var;
use std::path::{Path, PathBuf};
use either::Either;
use crate::PgConnectOptions;
pub(crate) const HOST: &str = "localhost";
pub(crate) const PORT: u16 = 5432;
impl Default for PgConnectOptions {
fn default() -> Self {
let port = var("PGPORT").ok().and_then(|v| v.parse().ok()).unwrap_or(PORT);
let mut self_ = Self {
address: default_address(port),
username: var("PGUSER").ok(),
password: var("PGPASSWORD").ok(),
database: var("PGDATABASE").ok(),
application_name: var("PGAPPNAME").ok(),
};
if let Some(host) = var("PGHOSTADDR").ok().or_else(|| var("PGHOST").ok()) {
// apply PGHOST down here to let someone set a socket
// path via PGHOST
self_.host(&host);
}
self_
}
}
impl PgConnectOptions {
/// Creates a default set of options ready for configuration.
#[must_use]
pub fn new() -> Self {
Self::default()
}
}
fn default_address(port: u16) -> Either<(String, u16), PathBuf> {
// try to check for the existence of a unix socket and uses that
let socket = format!(".s.PGSQL.{}", port);
let candidates = [
"/var/run/postgresql", // Debian
"/private/tmp", // OSX (homebrew)
"/tmp", // Default
];
for candidate in &candidates {
if Path::new(candidate).join(&socket).exists() {
return Either::Right(PathBuf::from(candidate));
}
}
Either::Left((HOST.to_owned(), port))
}

View File

@@ -0,0 +1,47 @@
use std::path::{Path, PathBuf};
use super::{default, PgConnectOptions};
impl PgConnectOptions {
/// Returns the hostname of the database server.
#[must_use]
pub fn get_host(&self) -> &str {
self.address.as_ref().left().map_or(default::HOST, |(host, _)| &**host)
}
/// Returns the TCP port number of the database server.
#[must_use]
pub fn get_port(&self) -> u16 {
self.address.as_ref().left().map_or(default::PORT, |(_, port)| *port)
}
/// Returns the path to the Unix domain socket, if one is configured.
#[must_use]
pub fn get_socket(&self) -> Option<&Path> {
self.address.as_ref().right().map(PathBuf::as_path)
}
/// Returns the default database name.
#[must_use]
pub fn get_database(&self) -> Option<&str> {
self.database.as_deref()
}
/// Returns the username to be used for authentication.
#[must_use]
pub fn get_username(&self) -> Option<&str> {
self.username.as_deref()
}
/// Returns the password to be used for authentication.
#[must_use]
pub fn get_password(&self) -> Option<&str> {
self.password.as_deref()
}
/// Returns the application name for the connection.
#[must_use]
pub fn get_application_name(&self) -> Option<&str> {
self.application_name.as_deref()
}
}

View File

@@ -0,0 +1,174 @@
use std::borrow::Cow;
use std::str::FromStr;
use percent_encoding::percent_decode_str;
use sqlx_core::Error;
use url::Url;
use crate::PgConnectOptions;
impl FromStr for PgConnectOptions {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let url: Url = s.parse().map_err(|error| Error::opt("for database URL", error))?;
if !matches!(url.scheme(), "postgres" | "postgresql") {
return Err(Error::opt_msg(format!(
"unsupported URL scheme {:?} for Postgres",
url.scheme()
)));
}
let mut options = Self::new();
if let Some(host) = url.host_str() {
options.host(percent_decode_str_utf8(host));
}
if let Some(port) = url.port() {
options.port(port);
}
let username = url.username();
if !username.is_empty() {
options.username(percent_decode_str_utf8(username));
}
if let Some(password) = url.password() {
options.password(percent_decode_str_utf8(password));
}
let mut path = url.path();
if path.starts_with('/') {
path = &path[1..];
}
if !path.is_empty() {
options.database(path);
}
for (key, value) in url.query_pairs() {
match &*key {
"host" | "hostaddr" => {
options.host(value);
}
"port" => {
options.port(value.parse().map_err(|err| Error::opt("for port", err))?);
}
"user" | "username" => {
options.username(value);
}
"password" => {
options.password(value);
}
"ssl-mode" | "sslmode" | "sslMode" | "tls" => {
todo!()
}
"socket" => {
options.socket(&*value);
}
"application_name" => {
options.application_name(&*value);
}
_ => {
// ignore unknown connection parameters
// fixme: should we error or warn here?
}
}
}
Ok(options)
}
}
fn percent_decode_str_utf8(value: &str) -> Cow<'_, str> {
percent_decode_str(value).decode_utf8_lossy()
}
#[cfg(test)]
mod tests {
use std::path::Path;
use super::PgConnectOptions;
#[test]
fn parse() {
let url = "postgresql://user:password@hostname:8915/database?application_name=sqlx";
let options: PgConnectOptions = url.parse().unwrap();
assert_eq!(options.get_username(), Some("user"));
assert_eq!(options.get_password(), Some("password"));
assert_eq!(options.get_host(), "hostname");
assert_eq!(options.get_port(), 8915);
assert_eq!(options.get_database(), Some("database"));
assert_eq!(options.get_application_name(), Some("sqlx"));
}
#[test]
fn parse_with_defaults() {
let url = "postgres://";
let options: PgConnectOptions = url.parse().unwrap();
assert_eq!(options.get_username(), None);
assert_eq!(options.get_password(), None);
assert_eq!(options.get_host(), "localhost");
assert_eq!(options.get_port(), 5432);
assert_eq!(options.get_database(), None);
assert_eq!(options.get_application_name(), None);
}
#[test]
fn parse_socket_from_query() {
let url = "postgresql://user:password@localhost/database?socket=/var/run/postgresql.sock";
let options: PgConnectOptions = url.parse().unwrap();
assert_eq!(options.get_username(), Some("user"));
assert_eq!(options.get_password(), Some("password"));
assert_eq!(options.get_database(), Some("database"));
assert_eq!(options.get_socket(), Some(Path::new("/var/run/postgresql.sock")));
}
#[test]
fn parse_socket_from_host() {
// socket path in host requires URL encoding - but does work
let url = "postgres://user:password@%2Fvar%2Frun%2Fpostgres%2Fpostgres.sock/database";
let options: PgConnectOptions = url.parse().unwrap();
assert_eq!(options.get_username(), Some("user"));
assert_eq!(options.get_password(), Some("password"));
assert_eq!(options.get_database(), Some("database"));
assert_eq!(options.get_socket(), Some(Path::new("/var/run/postgres/postgres.sock")));
}
#[test]
#[should_panic]
fn fail_to_parse_non_postgres() {
let url = "mysql://user:password@hostname:5432/database?timezone=system&charset=utf8";
let _: PgConnectOptions = url.parse().unwrap();
}
#[test]
fn parse_username_with_at_sign() {
let url = "postgres://user@hostname:password@hostname:5432/database";
let options: PgConnectOptions = url.parse().unwrap();
assert_eq!(options.get_username(), Some("user@hostname"));
}
#[test]
fn parse_password_with_non_ascii_chars() {
let url = "postgres://username:p@ssw0rd@hostname:5432/database";
let options: PgConnectOptions = url.parse().unwrap();
assert_eq!(options.get_password(), Some("p@ssw0rd"));
}
}