From b29eab0439b9914fdae20aa6e2ca6af0e5dc4969 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=ADlian?= <69756012+lily-mosquitoes@users.noreply.github.com> Date: Wed, 6 Mar 2024 05:46:49 +0200 Subject: [PATCH] feat: add `to_url_lossy` to connect options (#2902) * feat: add get_url to connect options Add a get_url to connect options and implement it for all needed types; include get_filename for sqlite. These changes make it easier to test sqlx. * refactor: use expect with message * refactor: change method name to `to_url_lossy` * fix: remove unused imports --- sqlx-core/src/any/options.rs | 4 ++ sqlx-core/src/connection.rs | 29 +++++++++ sqlx-mysql/src/options/connect.rs | 4 ++ sqlx-mysql/src/options/parse.rs | 76 +++++++++++++++++++++++- sqlx-postgres/src/options/connect.rs | 4 ++ sqlx-postgres/src/options/parse.rs | 88 +++++++++++++++++++++++++++- sqlx-sqlite/src/options/connect.rs | 4 ++ sqlx-sqlite/src/options/mod.rs | 5 ++ sqlx-sqlite/src/options/parse.rs | 44 +++++++++++++- 9 files changed, 253 insertions(+), 5 deletions(-) diff --git a/sqlx-core/src/any/options.rs b/sqlx-core/src/any/options.rs index dfa677e6..bb29d817 100644 --- a/sqlx-core/src/any/options.rs +++ b/sqlx-core/src/any/options.rs @@ -43,6 +43,10 @@ impl ConnectOptions for AnyConnectOptions { }) } + fn to_url_lossy(&self) -> Url { + self.database_url.clone() + } + #[inline] fn connect(&self) -> BoxFuture<'_, Result> { AnyConnection::connect(self) diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index f254344a..584e5c47 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -189,6 +189,35 @@ pub trait ConnectOptions: 'static + Send + Sync + FromStr + Debug + /// Parse the `ConnectOptions` from a URL. fn from_url(url: &Url) -> Result; + /// Get a connection URL that may be used to connect to the same database as this `ConnectOptions`. + /// + /// ### Note: Lossy + /// Any flags or settings which do not have a representation in the URL format will be lost. + /// They will fall back to their default settings when the URL is parsed. + /// + /// The only settings guaranteed to be preserved are: + /// * Username + /// * Password + /// * Hostname + /// * Port + /// * Database name + /// * Unix socket or SQLite database file path + /// * SSL mode (if applicable) + /// * SSL CA certificate path + /// * SSL client certificate path + /// * SSL client key path + /// + /// Additional settings are driver-specific. Refer to the source of a given implementation + /// to see which options are preserved in the URL. + /// + /// ### Panics + /// This defaults to `unimplemented!()`. + /// + /// Individual drivers should override this to implement the intended behavior. + fn to_url_lossy(&self) -> Url { + unimplemented!() + } + /// Establish a new database connection with the options specified by `self`. fn connect(&self) -> BoxFuture<'_, Result> where diff --git a/sqlx-mysql/src/options/connect.rs b/sqlx-mysql/src/options/connect.rs index 0b52a761..4c89b439 100644 --- a/sqlx-mysql/src/options/connect.rs +++ b/sqlx-mysql/src/options/connect.rs @@ -14,6 +14,10 @@ impl ConnectOptions for MySqlConnectOptions { Self::parse_from_url(url) } + fn to_url_lossy(&self) -> Url { + self.build_url() + } + fn connect(&self) -> BoxFuture<'_, Result> where Self::Connection: Sized, diff --git a/sqlx-mysql/src/options/parse.rs b/sqlx-mysql/src/options/parse.rs index 5ba5c320..971510ca 100644 --- a/sqlx-mysql/src/options/parse.rs +++ b/sqlx-mysql/src/options/parse.rs @@ -1,9 +1,9 @@ use std::str::FromStr; -use percent_encoding::percent_decode_str; +use percent_encoding::{percent_decode_str, utf8_percent_encode, NON_ALPHANUMERIC}; use sqlx_core::Url; -use crate::error::Error; +use crate::{error::Error, MySqlSslMode}; use super::MySqlConnectOptions; @@ -78,6 +78,65 @@ impl MySqlConnectOptions { Ok(options) } + + pub(crate) fn build_url(&self) -> Url { + let mut url = Url::parse(&format!( + "mysql://{}@{}:{}", + self.username, self.host, self.port + )) + .expect("BUG: generated un-parseable URL"); + + if let Some(password) = &self.password { + let password = utf8_percent_encode(&password, NON_ALPHANUMERIC).to_string(); + let _ = url.set_password(Some(&password)); + } + + if let Some(database) = &self.database { + url.set_path(&database); + } + + let ssl_mode = match self.ssl_mode { + MySqlSslMode::Disabled => "DISABLED", + MySqlSslMode::Preferred => "PREFERRED", + MySqlSslMode::Required => "REQUIRED", + MySqlSslMode::VerifyCa => "VERIFY_CA", + MySqlSslMode::VerifyIdentity => "VERIFY_IDENTITY", + }; + url.query_pairs_mut().append_pair("ssl-mode", ssl_mode); + + if let Some(ssl_ca) = &self.ssl_ca { + url.query_pairs_mut() + .append_pair("ssl-ca", &ssl_ca.to_string()); + } + + url.query_pairs_mut().append_pair("charset", &self.charset); + + if let Some(collation) = &self.collation { + url.query_pairs_mut().append_pair("charset", &collation); + } + + if let Some(ssl_client_cert) = &self.ssl_client_cert { + url.query_pairs_mut() + .append_pair("ssl-cert", &ssl_client_cert.to_string()); + } + + if let Some(ssl_client_key) = &self.ssl_client_key { + url.query_pairs_mut() + .append_pair("ssl-key", &ssl_client_key.to_string()); + } + + url.query_pairs_mut().append_pair( + "statement-cache-capacity", + &self.statement_cache_capacity.to_string(), + ); + + if let Some(socket) = &self.socket { + url.query_pairs_mut() + .append_pair("socket", &socket.to_string_lossy()); + } + + url + } } impl FromStr for MySqlConnectOptions { @@ -104,3 +163,16 @@ fn it_parses_password_with_non_ascii_chars_correctly() { assert_eq!(Some("p@ssw0rd".into()), opts.password); } + +#[test] +fn it_returns_the_parsed_url() { + let url = "mysql://username:p@ssw0rd@hostname:3306/database"; + let opts = MySqlConnectOptions::from_str(url).unwrap(); + + let mut expected_url = Url::parse(url).unwrap(); + // MySqlConnectOptions defaults + let query_string = "ssl-mode=PREFERRED&charset=utf8mb4&statement-cache-capacity=100"; + expected_url.set_query(Some(query_string)); + + assert_eq!(expected_url, opts.build_url()); +} diff --git a/sqlx-postgres/src/options/connect.rs b/sqlx-postgres/src/options/connect.rs index f61909a9..bc6e4adc 100644 --- a/sqlx-postgres/src/options/connect.rs +++ b/sqlx-postgres/src/options/connect.rs @@ -13,6 +13,10 @@ impl ConnectOptions for PgConnectOptions { Self::parse_from_url(url) } + fn to_url_lossy(&self) -> Url { + self.build_url() + } + fn connect(&self) -> BoxFuture<'_, Result> where Self::Connection: Sized, diff --git a/sqlx-postgres/src/options/parse.rs b/sqlx-postgres/src/options/parse.rs index 4c5cf41c..559516c0 100644 --- a/sqlx-postgres/src/options/parse.rs +++ b/sqlx-postgres/src/options/parse.rs @@ -1,6 +1,6 @@ use crate::error::Error; -use crate::PgConnectOptions; -use sqlx_core::percent_encoding::percent_decode_str; +use crate::{PgConnectOptions, PgSslMode}; +use sqlx_core::percent_encoding::{percent_decode_str, utf8_percent_encode, NON_ALPHANUMERIC}; use sqlx_core::Url; use std::net::IpAddr; use std::str::FromStr; @@ -108,6 +108,62 @@ impl PgConnectOptions { Ok(options) } + + pub(crate) fn build_url(&self) -> Url { + let host = match &self.socket { + Some(socket) => { + utf8_percent_encode(&*socket.to_string_lossy(), NON_ALPHANUMERIC).to_string() + } + None => self.host.to_owned(), + }; + + let mut url = Url::parse(&format!( + "postgres://{}@{}:{}", + self.username, host, self.port + )) + .expect("BUG: generated un-parseable URL"); + + if let Some(password) = &self.password { + let password = utf8_percent_encode(&password, NON_ALPHANUMERIC).to_string(); + let _ = url.set_password(Some(&password)); + } + + if let Some(database) = &self.database { + url.set_path(&database); + } + + let ssl_mode = match self.ssl_mode { + PgSslMode::Allow => "ALLOW", + PgSslMode::Disable => "DISABLED", + PgSslMode::Prefer => "PREFERRED", + PgSslMode::Require => "REQUIRED", + PgSslMode::VerifyCa => "VERIFY_CA", + PgSslMode::VerifyFull => "VERIFY_FULL", + }; + url.query_pairs_mut().append_pair("ssl-mode", ssl_mode); + + if let Some(ssl_root_cert) = &self.ssl_root_cert { + url.query_pairs_mut() + .append_pair("ssl-root-cert", &ssl_root_cert.to_string()); + } + + if let Some(ssl_client_cert) = &self.ssl_client_cert { + url.query_pairs_mut() + .append_pair("ssl-cert", &ssl_client_cert.to_string()); + } + + if let Some(ssl_client_key) = &self.ssl_client_key { + url.query_pairs_mut() + .append_pair("ssl-key", &ssl_client_key.to_string()); + } + + url.query_pairs_mut().append_pair( + "statement-cache-capacity", + &self.statement_cache_capacity.to_string(), + ); + + url + } } impl FromStr for PgConnectOptions { @@ -242,3 +298,31 @@ fn it_parses_sqlx_options_correctly() { opts.options ); } + +#[test] +fn it_returns_the_parsed_url_when_socket() { + let url = "postgres://username@%2Fvar%2Flib%2Fpostgres/database"; + let opts = PgConnectOptions::from_str(url).unwrap(); + + let mut expected_url = Url::parse(url).unwrap(); + // PgConnectOptions defaults + let query_string = "ssl-mode=PREFERRED&statement-cache-capacity=100"; + let port = 5432; + expected_url.set_query(Some(query_string)); + let _ = expected_url.set_port(Some(port)); + + assert_eq!(expected_url, opts.build_url()); +} + +#[test] +fn it_returns_the_parsed_url_when_host() { + let url = "postgres://username:p@ssw0rd@hostname:5432/database"; + let opts = PgConnectOptions::from_str(url).unwrap(); + + let mut expected_url = Url::parse(url).unwrap(); + // PgConnectOptions defaults + let query_string = "ssl-mode=PREFERRED&statement-cache-capacity=100"; + expected_url.set_query(Some(query_string)); + + assert_eq!(expected_url, opts.build_url()); +} diff --git a/sqlx-sqlite/src/options/connect.rs b/sqlx-sqlite/src/options/connect.rs index 5545cfa4..309f2430 100644 --- a/sqlx-sqlite/src/options/connect.rs +++ b/sqlx-sqlite/src/options/connect.rs @@ -24,6 +24,10 @@ impl ConnectOptions for SqliteConnectOptions { Self::from_str(url.as_str()) } + fn to_url_lossy(&self) -> Url { + self.build_url() + } + fn connect(&self) -> BoxFuture<'_, Result> where Self::Connection: Sized, diff --git a/sqlx-sqlite/src/options/mod.rs b/sqlx-sqlite/src/options/mod.rs index 33875720..ac45b84e 100644 --- a/sqlx-sqlite/src/options/mod.rs +++ b/sqlx-sqlite/src/options/mod.rs @@ -211,6 +211,11 @@ impl SqliteConnectOptions { self } + /// Gets the current name of the database file. + pub fn get_filename(self) -> Cow<'static, Path> { + self.filename + } + /// Set the enforcement of [foreign key constraints](https://www.sqlite.org/pragma.html#pragma_foreign_keys). /// /// SQLx chooses to enable this by default so that foreign keys function as expected, diff --git a/sqlx-sqlite/src/options/parse.rs b/sqlx-sqlite/src/options/parse.rs index a2cab10b..aab61b9b 100644 --- a/sqlx-sqlite/src/options/parse.rs +++ b/sqlx-sqlite/src/options/parse.rs @@ -1,10 +1,11 @@ use crate::error::Error; use crate::SqliteConnectOptions; -use percent_encoding::percent_decode_str; +use percent_encoding::{percent_decode_str, utf8_percent_encode, NON_ALPHANUMERIC}; use std::borrow::Cow; use std::path::{Path, PathBuf}; use std::str::FromStr; use std::sync::atomic::{AtomicUsize, Ordering}; +use url::Url; // https://www.sqlite.org/uri.html @@ -111,6 +112,36 @@ impl SqliteConnectOptions { Ok(options) } + + pub(crate) fn build_url(&self) -> Url { + let filename = + utf8_percent_encode(&self.filename.to_string_lossy(), NON_ALPHANUMERIC).to_string(); + let mut url = + Url::parse(&format!("sqlite://{}", filename)).expect("BUG: generated un-parseable URL"); + + let mode = match (self.in_memory, self.create_if_missing, self.read_only) { + (true, _, _) => "memory", + (false, true, _) => "rwc", + (false, false, true) => "ro", + (false, false, false) => "rw", + }; + url.query_pairs_mut().append_pair("mode", mode); + + let cache = match self.shared_cache { + true => "shared", + false => "private", + }; + url.query_pairs_mut().append_pair("cache", cache); + + url.query_pairs_mut() + .append_pair("immutable", &self.immutable.to_string()); + + if let Some(vfs) = &self.vfs { + url.query_pairs_mut().append_pair("vfs", &vfs); + } + + url + } } impl FromStr for SqliteConnectOptions { @@ -169,3 +200,14 @@ fn test_parse_shared_in_memory() -> Result<(), Error> { Ok(()) } + +#[test] +fn it_returns_the_parsed_url() -> Result<(), Error> { + let url = "sqlite://test.db?mode=rw&cache=shared"; + let options: SqliteConnectOptions = url.parse()?; + + let expected_url = Url::parse(url).unwrap(); + assert_eq!(options.build_url(), expected_url); + + Ok(()) +}