diff --git a/sqlx-mysql/src/options.rs b/sqlx-mysql/src/options.rs index 425467855..10a79eb80 100644 --- a/sqlx-mysql/src/options.rs +++ b/sqlx-mysql/src/options.rs @@ -13,7 +13,15 @@ mod parse; /// Options which can be used to configure how a MySQL connection is opened. /// +/// A value of `MySqlConnectOptions` can be parsed from a connection URI, as +/// described by [dev.mysql.com](https://dev.mysql.com/doc/refman/8.0/en/connecting-using-uri-or-key-value-pairs.html#connecting-using-uri). +/// +/// ```text +/// mysql://[user[:password]@][host][:port][/database][?param1=value1&...] +/// ``` +/// #[allow(clippy::module_name_repetitions)] +#[derive(Clone)] pub struct MySqlConnectOptions { pub(crate) address: Either<(String, u16), PathBuf>, username: Option, @@ -23,19 +31,6 @@ pub struct MySqlConnectOptions { charset: String, } -impl Clone for MySqlConnectOptions { - fn clone(&self) -> Self { - Self { - address: self.address.clone(), - username: self.username.clone(), - password: self.password.clone(), - database: self.database.clone(), - timezone: self.timezone.clone(), - charset: self.charset.clone(), - } - } -} - impl Debug for MySqlConnectOptions { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("MySqlConnectOptions") diff --git a/sqlx-mysql/src/options/builder.rs b/sqlx-mysql/src/options/builder.rs index 71a067d3b..4a5b8cb8d 100644 --- a/sqlx-mysql/src/options/builder.rs +++ b/sqlx-mysql/src/options/builder.rs @@ -25,7 +25,7 @@ impl super::MySqlConnectOptions { /// Sets the path of the Unix domain socket to connect to. /// - /// Overrides [`host()`](#method.host) and [`port()`](#method.port). + /// Overrides [`host`](#method.host) and [`port`](#method.port). /// pub fn socket(&mut self, socket: impl AsRef) -> &mut Self { self.address = Either::Right(socket.as_ref().to_owned()); @@ -46,7 +46,6 @@ impl super::MySqlConnectOptions { } /// Sets the username to be used for authentication. - // FIXME: Specify what happens when you do NOT set this pub fn username(&mut self, username: impl AsRef) -> &mut Self { self.username = Some(username.as_ref().to_owned()); self diff --git a/sqlx-mysql/src/options/default.rs b/sqlx-mysql/src/options/default.rs index 2b292a53b..745f14ce1 100644 --- a/sqlx-mysql/src/options/default.rs +++ b/sqlx-mysql/src/options/default.rs @@ -14,7 +14,6 @@ impl Default for MySqlConnectOptions { database: None, charset: "utf8mb4".to_owned(), timezone: "utc".to_owned(), - // todo: connect_timeout } } } diff --git a/sqlx-mysql/src/options/parse.rs b/sqlx-mysql/src/options/parse.rs index 581d2532e..9d6b877b7 100644 --- a/sqlx-mysql/src/options/parse.rs +++ b/sqlx-mysql/src/options/parse.rs @@ -59,10 +59,6 @@ impl FromStr for MySqlConnectOptions { options.password(value); } - // ssl-mode compatibly with SQLx <= 0.5 - // sslmode compatibly with PostgreSQL - // sslMode compatibly with JDBC MySQL - // tls compatibly with Go MySQL [preferred] "ssl-mode" | "sslmode" | "sslMode" | "tls" => { todo!() } @@ -90,7 +86,6 @@ impl FromStr for MySqlConnectOptions { } } -// todo: this should probably go somewhere common fn percent_decode_str_utf8(value: &str) -> Cow<'_, str> { percent_decode_str(value).decode_utf8_lossy() } diff --git a/sqlx-postgres/Cargo.toml b/sqlx-postgres/Cargo.toml index 76e5e2ae8..d1fbd0d71 100644 --- a/sqlx-postgres/Cargo.toml +++ b/sqlx-postgres/Cargo.toml @@ -30,6 +30,7 @@ atoi = "0.4.0" sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core" } futures-util = { version = "0.3.8", optional = true } log = "0.4.11" +either = "1.6.1" bytestring = "1.0.0" url = "2.2.0" percent-encoding = "2.1.0" diff --git a/sqlx-postgres/src/lib.rs b/sqlx-postgres/src/lib.rs index 5a34e2798..20a831e02 100644 --- a/sqlx-postgres/src/lib.rs +++ b/sqlx-postgres/src/lib.rs @@ -30,7 +30,7 @@ mod column; mod database; // mod error; // mod io; -// mod options; +mod options; mod output; mod protocol; mod query_result; @@ -49,7 +49,7 @@ pub use column::PgColumn; // pub use connection::PgConnection; pub use database::Postgres; // pub use error::PgDatabaseError; -// pub use options::PgConnectOptions; +pub use options::PgConnectOptions; pub use output::PgOutput; pub use query_result::PgQueryResult; pub use raw_value::{PgRawValue, PgRawValueFormat}; diff --git a/sqlx-postgres/src/options.rs b/sqlx-postgres/src/options.rs new file mode 100644 index 000000000..4093288b5 --- /dev/null +++ b/sqlx-postgres/src/options.rs @@ -0,0 +1,53 @@ +use std::fmt::{self, Debug, Formatter}; +use std::path::PathBuf; + +use either::Either; +use sqlx_core::ConnectOptions; + +mod builder; +mod default; +mod getters; +mod parse; + +/// Options which can be used to configure how a Postgres connection is opened. +/// +/// A value of `PgConnectOptions` can be parsed from a connection URI, as +/// described by [libpq](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING). +/// +/// ```text +/// postgresql://[user[:password]@][host][:port][/database][?param1=value1&...] +/// ``` +/// +#[allow(clippy::module_name_repetitions)] +#[derive(Clone)] +pub struct PgConnectOptions { + pub(crate) address: Either<(String, u16), PathBuf>, + username: Option, + password: Option, + database: Option, + application_name: Option, +} + +impl Debug for PgConnectOptions { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("PgConnectOptions") + .field( + "address", + &self + .address + .as_ref() + .map_left(|(host, port)| format!("{}:{}", host, port)) + .map_right(|socket| socket.display()), + ) + .field("username", &self.username) + .field("password", &self.password) + .field("database", &self.database) + .field("application_name", &self.application_name) + .finish() + } +} + +impl ConnectOptions for PgConnectOptions {} + +#[cfg(feature = "blocking")] +impl sqlx_core::blocking::ConnectOptions for PgConnectOptions {} diff --git a/sqlx-postgres/src/options/builder.rs b/sqlx-postgres/src/options/builder.rs new file mode 100644 index 000000000..595729d91 --- /dev/null +++ b/sqlx-postgres/src/options/builder.rs @@ -0,0 +1,97 @@ +use std::mem; +use std::path::{Path, PathBuf}; + +use either::Either; + +impl super::PgConnectOptions { + /// Sets the hostname of the database server. + /// + /// If the hostname begins with a slash (`/`), it is interpreted as the absolute path + /// to a Unix domain socket file instead of a hostname of a server. + /// + /// Defaults to either the `PGHOSTADDR` or `PGHOST` environment variable, falling back + /// to `localhost` if neither are present. + /// + pub fn host(&mut self, host: impl AsRef) -> &mut Self { + let host = host.as_ref(); + + self.address = if host.starts_with('/') { + Either::Right(PathBuf::from(&*host)) + } else { + Either::Left((host.into(), self.get_port())) + }; + + self + } + + /// Sets the path of the Unix domain socket to connect to. + /// + /// Overrides [`host`](#method.host). + /// + /// Defaults to, and overrides a default `host`, if one of the files is present in + /// the local filesystem: + /// + /// - `/var/run/postgresql/.s.PGSQL.{port}` + /// - `/private/tmp/.s.PGSQL.{port}` + /// - `/tmp/.s.PGSQL.{port}` + /// + pub fn socket(&mut self, socket: impl AsRef) -> &mut Self { + self.address = Either::Right(socket.as_ref().to_owned()); + self + } + + /// Sets the TCP port number of the database server. + /// + /// Defaults to the `PGPORT` environment variable, falling back to `5432` + /// if not present. + /// + pub fn port(&mut self, port: u16) -> &mut Self { + self.address = match self.address { + Either::Right(_) => Either::Left(("localhost".to_owned(), port)), + Either::Left((ref mut host, _)) => Either::Left((mem::take(host), port)), + }; + + self + } + + /// Sets the user to be used for authentication. + /// + /// Defaults to the `PGUSER` environment variable, if present. + /// + pub fn username(&mut self, username: impl AsRef) -> &mut Self { + self.username = Some(username.as_ref().to_owned()); + self + } + + /// Sets the password to be used for authentication. + /// + /// Defaults to the `PGPASSWORD` environment variable, if present. + /// + pub fn password(&mut self, password: impl AsRef) -> &mut Self { + self.password = Some(password.as_ref().to_owned()); + self + } + + /// Sets the database for the connection. + /// + /// Defaults to the `PGDATABASE` environment variable, falling back to + /// the name of the user, if not present. + /// + pub fn database(&mut self, database: impl AsRef) -> &mut Self { + self.database = Some(database.as_ref().to_owned()); + self + } + + /// Sets the application name for the connection. + /// + /// The name will be displayed in the `pg_stat_activity` view and + /// included in CSV log entries. Only printable ASCII characters may be + /// used in the `application_name` value. + /// + /// Defaults to the `PGAPPNAME` environment variable, if present. + /// + pub fn application_name(&mut self, name: impl AsRef) -> &mut Self { + self.application_name = Some(name.as_ref().to_owned()); + self + } +} diff --git a/sqlx-postgres/src/options/default.rs b/sqlx-postgres/src/options/default.rs new file mode 100644 index 000000000..4bb563d21 --- /dev/null +++ b/sqlx-postgres/src/options/default.rs @@ -0,0 +1,57 @@ +use std::env::var; +use std::path::{Path, PathBuf}; + +use either::Either; + +use crate::PgConnectOptions; + +pub(crate) const HOST: &str = "localhost"; +pub(crate) const PORT: u16 = 5432; + +impl Default for PgConnectOptions { + fn default() -> Self { + let port = var("PGPORT").ok().and_then(|v| v.parse().ok()).unwrap_or(PORT); + + let mut self_ = Self { + address: default_address(port), + username: var("PGUSER").ok(), + password: var("PGPASSWORD").ok(), + database: var("PGDATABASE").ok(), + application_name: var("PGAPPNAME").ok(), + }; + + if let Some(host) = var("PGHOSTADDR").ok().or_else(|| var("PGHOST").ok()) { + // apply PGHOST down here to let someone set a socket + // path via PGHOST + self_.host(&host); + } + + self_ + } +} + +impl PgConnectOptions { + /// Creates a default set of options ready for configuration. + #[must_use] + pub fn new() -> Self { + Self::default() + } +} + +fn default_address(port: u16) -> Either<(String, u16), PathBuf> { + // try to check for the existence of a unix socket and uses that + let socket = format!(".s.PGSQL.{}", port); + let candidates = [ + "/var/run/postgresql", // Debian + "/private/tmp", // OSX (homebrew) + "/tmp", // Default + ]; + + for candidate in &candidates { + if Path::new(candidate).join(&socket).exists() { + return Either::Right(PathBuf::from(candidate)); + } + } + + Either::Left((HOST.to_owned(), port)) +} diff --git a/sqlx-postgres/src/options/getters.rs b/sqlx-postgres/src/options/getters.rs new file mode 100644 index 000000000..d1869932d --- /dev/null +++ b/sqlx-postgres/src/options/getters.rs @@ -0,0 +1,47 @@ +use std::path::{Path, PathBuf}; + +use super::{default, PgConnectOptions}; + +impl PgConnectOptions { + /// Returns the hostname of the database server. + #[must_use] + pub fn get_host(&self) -> &str { + self.address.as_ref().left().map_or(default::HOST, |(host, _)| &**host) + } + + /// Returns the TCP port number of the database server. + #[must_use] + pub fn get_port(&self) -> u16 { + self.address.as_ref().left().map_or(default::PORT, |(_, port)| *port) + } + + /// Returns the path to the Unix domain socket, if one is configured. + #[must_use] + pub fn get_socket(&self) -> Option<&Path> { + self.address.as_ref().right().map(PathBuf::as_path) + } + + /// Returns the default database name. + #[must_use] + pub fn get_database(&self) -> Option<&str> { + self.database.as_deref() + } + + /// Returns the username to be used for authentication. + #[must_use] + pub fn get_username(&self) -> Option<&str> { + self.username.as_deref() + } + + /// Returns the password to be used for authentication. + #[must_use] + pub fn get_password(&self) -> Option<&str> { + self.password.as_deref() + } + + /// Returns the application name for the connection. + #[must_use] + pub fn get_application_name(&self) -> Option<&str> { + self.application_name.as_deref() + } +} diff --git a/sqlx-postgres/src/options/parse.rs b/sqlx-postgres/src/options/parse.rs new file mode 100644 index 000000000..eaf3af8a6 --- /dev/null +++ b/sqlx-postgres/src/options/parse.rs @@ -0,0 +1,174 @@ +use std::borrow::Cow; +use std::str::FromStr; + +use percent_encoding::percent_decode_str; +use sqlx_core::Error; +use url::Url; + +use crate::PgConnectOptions; + +impl FromStr for PgConnectOptions { + type Err = Error; + + fn from_str(s: &str) -> Result { + let url: Url = s.parse().map_err(|error| Error::opt("for database URL", error))?; + + if !matches!(url.scheme(), "postgres" | "postgresql") { + return Err(Error::opt_msg(format!( + "unsupported URL scheme {:?} for Postgres", + url.scheme() + ))); + } + + let mut options = Self::new(); + + if let Some(host) = url.host_str() { + options.host(percent_decode_str_utf8(host)); + } + + if let Some(port) = url.port() { + options.port(port); + } + + let username = url.username(); + if !username.is_empty() { + options.username(percent_decode_str_utf8(username)); + } + + if let Some(password) = url.password() { + options.password(percent_decode_str_utf8(password)); + } + + let mut path = url.path(); + + if path.starts_with('/') { + path = &path[1..]; + } + + if !path.is_empty() { + options.database(path); + } + + for (key, value) in url.query_pairs() { + match &*key { + "host" | "hostaddr" => { + options.host(value); + } + + "port" => { + options.port(value.parse().map_err(|err| Error::opt("for port", err))?); + } + + "user" | "username" => { + options.username(value); + } + + "password" => { + options.password(value); + } + + "ssl-mode" | "sslmode" | "sslMode" | "tls" => { + todo!() + } + + "socket" => { + options.socket(&*value); + } + + "application_name" => { + options.application_name(&*value); + } + + _ => { + // ignore unknown connection parameters + // fixme: should we error or warn here? + } + } + } + + Ok(options) + } +} + +fn percent_decode_str_utf8(value: &str) -> Cow<'_, str> { + percent_decode_str(value).decode_utf8_lossy() +} + +#[cfg(test)] +mod tests { + use std::path::Path; + + use super::PgConnectOptions; + + #[test] + fn parse() { + let url = "postgresql://user:password@hostname:8915/database?application_name=sqlx"; + let options: PgConnectOptions = url.parse().unwrap(); + + assert_eq!(options.get_username(), Some("user")); + assert_eq!(options.get_password(), Some("password")); + assert_eq!(options.get_host(), "hostname"); + assert_eq!(options.get_port(), 8915); + assert_eq!(options.get_database(), Some("database")); + assert_eq!(options.get_application_name(), Some("sqlx")); + } + + #[test] + fn parse_with_defaults() { + let url = "postgres://"; + let options: PgConnectOptions = url.parse().unwrap(); + + assert_eq!(options.get_username(), None); + assert_eq!(options.get_password(), None); + assert_eq!(options.get_host(), "localhost"); + assert_eq!(options.get_port(), 5432); + assert_eq!(options.get_database(), None); + assert_eq!(options.get_application_name(), None); + } + + #[test] + fn parse_socket_from_query() { + let url = "postgresql://user:password@localhost/database?socket=/var/run/postgresql.sock"; + let options: PgConnectOptions = url.parse().unwrap(); + + assert_eq!(options.get_username(), Some("user")); + assert_eq!(options.get_password(), Some("password")); + assert_eq!(options.get_database(), Some("database")); + assert_eq!(options.get_socket(), Some(Path::new("/var/run/postgresql.sock"))); + } + + #[test] + fn parse_socket_from_host() { + // socket path in host requires URL encoding - but does work + let url = "postgres://user:password@%2Fvar%2Frun%2Fpostgres%2Fpostgres.sock/database"; + let options: PgConnectOptions = url.parse().unwrap(); + + assert_eq!(options.get_username(), Some("user")); + assert_eq!(options.get_password(), Some("password")); + assert_eq!(options.get_database(), Some("database")); + assert_eq!(options.get_socket(), Some(Path::new("/var/run/postgres/postgres.sock"))); + } + + #[test] + #[should_panic] + fn fail_to_parse_non_postgres() { + let url = "mysql://user:password@hostname:5432/database?timezone=system&charset=utf8"; + let _: PgConnectOptions = url.parse().unwrap(); + } + + #[test] + fn parse_username_with_at_sign() { + let url = "postgres://user@hostname:password@hostname:5432/database"; + let options: PgConnectOptions = url.parse().unwrap(); + + assert_eq!(options.get_username(), Some("user@hostname")); + } + + #[test] + fn parse_password_with_non_ascii_chars() { + let url = "postgres://username:p@ssw0rd@hostname:5432/database"; + let options: PgConnectOptions = url.parse().unwrap(); + + assert_eq!(options.get_password(), Some("p@ssw0rd")); + } +} diff --git a/x.py b/x.py index 9e80fade1..6e5ddee0c 100755 --- a/x.py +++ b/x.py @@ -172,6 +172,7 @@ def main(): # run unit tests, collect test binary filenames run_unit_test("sqlx-core") run_unit_test("sqlx-mysql") + run_unit_test("sqlx-postgres") run_unit_test("sqlx") if test_object_filenames and argv.coverage: