From 80a1b19db91fb2972907eac2b1f9d5469d5695d7 Mon Sep 17 00:00:00 2001 From: Daniel Akhterov Date: Sun, 24 Jan 2021 13:54:57 -0800 Subject: [PATCH] feat: add sasl support --- Cargo.lock | 46 ++++- sqlx-postgres/Cargo.toml | 7 +- sqlx-postgres/src/connection.rs | 3 + sqlx-postgres/src/connection/connect.rs | 18 +- sqlx-postgres/src/connection/sasl.rs | 239 ++++++++++++++++++++++++ sqlx-postgres/src/error.rs | 24 ++- sqlx-postgres/src/protocol.rs | 2 + sqlx-postgres/src/protocol/password.rs | 21 ++- sqlx-postgres/src/protocol/sasl.rs | 40 ++++ 9 files changed, 373 insertions(+), 27 deletions(-) create mode 100644 sqlx-postgres/src/connection/sasl.rs create mode 100644 sqlx-postgres/src/protocol/sasl.rs diff --git a/Cargo.lock b/Cargo.lock index 82327364..496a9bf1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -382,6 +382,16 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "crypto-mac" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4857fd85a0c34b3c3297875b747c1e02e06b6a0ea32dd892d8192b9ce0813ea6" +dependencies = [ + "generic-array", + "subtle", +] + [[package]] name = "digest" version = "0.9.0" @@ -551,6 +561,16 @@ dependencies = [ "libc", ] +[[package]] +name = "hmac" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1441c6b1e930e2817404b5046f1f989899143a12bf92de603b69f4e0aee1e15" +dependencies = [ + "crypto-mac", + "digest", +] + [[package]] name = "idna" version = "0.2.0" @@ -641,10 +661,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08" [[package]] -name = "md5" -version = "0.7.0" +name = "md-5" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" +checksum = "7b5a279bb9607f9f53c22d496eade00d138d1bdcccd07d74650387cf94942a15" +dependencies = [ + "block-buffer", + "digest", + "opaque-debug", +] [[package]] name = "memchr" @@ -1131,13 +1156,15 @@ dependencies = [ "byteorder", "bytes", "bytestring", + "crypto-mac", "either", "futures-executor", "futures-io", "futures-util", + "hmac", "itoa", "log", - "md5", + "md-5", "memchr", "percent-encoding", "rand", @@ -1145,9 +1172,20 @@ dependencies = [ "sha-1", "sha2", "sqlx-core", + "stringprep", "url", ] +[[package]] +name = "stringprep" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ee348cb74b87454fff4b551cbf727025810a004f88aeacae7f85b87f4e9a1c1" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "subtle" version = "2.4.0" diff --git a/sqlx-postgres/Cargo.toml b/sqlx-postgres/Cargo.toml index 57e48372..7f1eb05f 100644 --- a/sqlx-postgres/Cargo.toml +++ b/sqlx-postgres/Cargo.toml @@ -45,8 +45,11 @@ base64 = "0.13.0" rand = "0.7" itoa = "0.4.7" atoi = "0.4.0" -byteorder = "1.4.2" -md5 = "0.7.0" +byteorder = { version = "1.4.2", default-features = false, features = [ "std" ] } +md-5 = { version = "0.9.1", default-features = false } +hmac = { version = "0.10.1", default-features = false } +stringprep = "0.1.2" +crypto-mac = "0.10.0" [dev-dependencies] sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core", features = ["_mock"] } diff --git a/sqlx-postgres/src/connection.rs b/sqlx-postgres/src/connection.rs index 55b86835..3527e340 100644 --- a/sqlx-postgres/src/connection.rs +++ b/sqlx-postgres/src/connection.rs @@ -6,6 +6,9 @@ use sqlx_core::{Close, Connect, Connection, Runtime}; use crate::{Postgres, PostgresConnectOptions}; +#[macro_use] +mod sasl; + mod close; mod connect; mod ping; diff --git a/sqlx-postgres/src/connection/connect.rs b/sqlx-postgres/src/connection/connect.rs index 3e01c131..294b751b 100644 --- a/sqlx-postgres/src/connection/connect.rs +++ b/sqlx-postgres/src/connection/connect.rs @@ -11,14 +11,17 @@ //! //! https://dev.postgres.com/doc/internals/en/connection-phase.html //! +use hmac::{Hmac, Mac, NewMac}; +use sha2::{Digest, Sha256}; use sqlx_core::net::Stream as NetStream; use sqlx_core::Error; use sqlx_core::Result; use crate::protocol::{ - Authentication, BackendKeyData, Message, MessageType, Password, ReadyForQuery, Startup, + Authentication, BackendKeyData, Message, MessageType, Password, ReadyForQuery, + SaslInitialResponse, SaslResponse, Startup, }; -use crate::{PostgresConnectOptions, PostgresConnection}; +use crate::{PostgresConnectOptions, PostgresConnection, PostgresDatabaseError}; macro_rules! connect { (@blocking @tcp $options:ident) => { @@ -48,10 +51,7 @@ macro_rules! connect { // To begin a session, a frontend opens a connection to the server // and sends a startup message. - let mut params = vec![ - // Sets the display format for date and time values, - // as well as the rules for interpreting ambiguous date input values. - ("DateStyle", "ISO, MDY"), + let mut params = vec![ // Sets the display format for date and time values, as well as the rules for interpreting ambiguous date input values. ("DateStyle", "ISO, MDY"), // Sets the client-side encoding (character set). // ("client_encoding", "UTF8"), @@ -114,9 +114,9 @@ macro_rules! connect { })?; } - // Authentication::Sasl(body) => { - // sasl::authenticate(&mut stream, $options, body).await?; - // } + Authentication::Sasl(body) => { + sasl_authenticate!($(@$blocking)? self_, $options, body) + } method => { return Err(Error::configuration_msg(format!( diff --git a/sqlx-postgres/src/connection/sasl.rs b/sqlx-postgres/src/connection/sasl.rs new file mode 100644 index 00000000..99f91e42 --- /dev/null +++ b/sqlx-postgres/src/connection/sasl.rs @@ -0,0 +1,239 @@ +use hmac::{Hmac, Mac, NewMac}; +use rand::Rng; +use sha2::digest::Digest; +use sha2::Sha256; +use sqlx_core::Error; +use sqlx_core::Result; + +use crate::protocol::{ + Authentication, AuthenticationSasl, MessageType, SaslInitialResponse, SaslResponse, +}; + +pub(super) const GS2_HEADER: &str = "n,,"; +pub(super) const CHANNEL_ATTR: &str = "c"; +pub(super) const USERNAME_ATTR: &str = "n"; +pub(super) const CLIENT_PROOF_ATTR: &str = "p"; +pub(super) const NONCE_ATTR: &str = "r"; + +macro_rules! sasl_authenticate { + (@blocking @packet $self:ident) => { + $self.read_packet()?; + }; + + (@packet $self:ident) => { + $self.read_packet_async().await?; + }; + + ($(@$blocking:ident)? $self:ident, $options:ident, $data:ident) => {{ + let mut has_sasl = false; + let mut has_sasl_plus = false; + let mut unknown = Vec::new(); + + for mechanism in $data.mechanisms() { + match mechanism { + "SCRAM-SHA-256" => { + has_sasl = true; + } + + "SCRAM-SHA-256-PLUS" => { + has_sasl_plus = true; + } + + _ => { + unknown.push(mechanism.to_owned()); + } + } + } + + if !has_sasl_plus && !has_sasl { + return Err(Error::connect(crate::PostgresDatabaseError::protocol(format!( + "unsupported SASL authentication mechanisms: {}", + unknown.join(", ") + )))); + } + + // channel-binding = "c=" base64 + let channel_binding = format!("{}={}", crate::connection::sasl::CHANNEL_ATTR, base64::encode(crate::connection::sasl::GS2_HEADER)); + + // "n=" saslname ;; Usernames are prepared using SASLprep. + let username = format!("{}={}", crate::connection::sasl::USERNAME_ATTR, $options.get_username().unwrap_or_default()); + let username = match stringprep::saslprep(&username) { + Ok(v) => v, + Err(err) => { + return Err(Error::connect(crate::PostgresDatabaseError::protocol(format!( + "failed to sasl prep the username: {:?}", err + )))); + } + }; + + // nonce = "r=" c-nonce [s-nonce] ;; Second part provided by server. + let nonce = crate::connection::sasl::gen_nonce(); + + // client-first-message-bare = [reserved-mext ","] username "," nonce ["," extensions] + let client_first_message_bare = + format!("{username},{nonce}", username = username, nonce = nonce); + + let client_first_message = format!( + "{gs2_header}{client_first_message_bare}", + gs2_header = crate::connection::sasl::GS2_HEADER, + client_first_message_bare = client_first_message_bare + ); + + $self.write_packet(&crate::protocol::SaslInitialResponse { response: &client_first_message, plus: false })?; + + let message: Message = sasl_authenticate!($(@$blocking)? @packet $self); + let cont = match message.r#type { + MessageType::Authentication => { + match message.decode()? { + Authentication::SaslContinue(data) => data, + + auth => { + return Err(Error::connect(PostgresDatabaseError::protocol(format!( + "expected SASLContinue but received {:?}", auth + )))); + } + } + } + + _ => { + todo!() + } + }; + + // SaltedPassword := Hi(Normalize(password), salt, i) + let salted_password = + crate::connection::sasl::hi($options.get_password().unwrap_or_default(), &cont.salt, cont.iterations)?; + + // ClientKey := HMAC(SaltedPassword, "Client Key") + let mut mac = Hmac::::new_varkey(&salted_password) + .map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?; + mac.update(b"Client Key"); + + let client_key = mac.finalize().into_bytes(); + + // StoredKey := H(ClientKey) + let stored_key = Sha256::digest(&client_key); + + // client-final-message-without-proof + let client_final_message_wo_proof = format!( + "{channel_binding},r={nonce}", + channel_binding = channel_binding, + nonce = &cont.nonce + ); + + // AuthMessage := client-first-message-bare + "," + server-first-message + "," + client-final-message-without-proof + let auth_message = format!( + "{client_first_message_bare},{server_first_message},{client_final_message_wo_proof}", + client_first_message_bare = client_first_message_bare, + server_first_message = cont.message, + client_final_message_wo_proof = client_final_message_wo_proof + ); + + // ClientSignature := HMAC(StoredKey, AuthMessage) + let mut mac = Hmac::::new_varkey(&stored_key) + .map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?; + mac.update(&auth_message.as_bytes()); + + let client_signature = mac.finalize().into_bytes(); + + // ClientProof := ClientKey XOR ClientSignature + let client_proof: Vec = + client_key.iter().zip(client_signature.iter()).map(|(&a, &b)| a ^ b).collect(); + + // ServerKey := HMAC(SaltedPassword, "Server Key") + let mut mac = Hmac::::new_varkey(&salted_password) + .map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?; + mac.update(b"Server Key"); + + let server_key = mac.finalize().into_bytes(); + + // ServerSignature := HMAC(ServerKey, AuthMessage) + let mut mac = Hmac::::new_varkey(&server_key) + .map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?; + mac.update(&auth_message.as_bytes()); + + // client-final-message = client-final-message-without-proof "," proof + let client_final_message = format!( + "{client_final_message_wo_proof},{client_proof_attr}={client_proof}", + client_final_message_wo_proof = client_final_message_wo_proof, + client_proof_attr = crate::connection::sasl::CLIENT_PROOF_ATTR, + client_proof = base64::encode(&client_proof) + ); + + $self.write_packet(&crate::protocol::SaslResponse(&client_final_message))?; + + let message: Message = sasl_authenticate!($(@$blocking)? @packet $self); + let data = match message.r#type { + MessageType::Authentication => { + match message.decode()? { + Authentication::SaslFinal(data) => data, + + auth => { + return Err(Error::connect(PostgresDatabaseError::protocol(format!( + "expected SASLContinue but received {:?}", auth + )))); + } + } + } + + r#type => { + return Err(Error::connect(PostgresDatabaseError::protocol(format!( + "Expected an authencation message type, found {:?}", r#type + )))); + } + }; + + // authentication is only considered valid if this verification passes + mac.verify(&data.verifier) + .map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?; + }}; +} + +// nonce is a sequence of random printable bytes +pub(super) fn gen_nonce() -> String { + let mut rng = rand::thread_rng(); + let count = rng.gen_range(64, 128); + + // printable = %x21-2B / %x2D-7E + // ;; Printable ASCII except ",". + // ;; Note that any "printable" is also + // ;; a valid "value". + let nonce: String = std::iter::repeat(()) + .map(|()| { + let mut c = rng.gen_range(0x21, 0x7F) as u8; + + while c == 0x2C { + c = rng.gen_range(0x21, 0x7F) as u8; + } + + c + }) + .take(count) + .map(|c| c as char) + .collect(); + + rng.gen_range(32, 128); + format!("{}={}", crate::connection::sasl::NONCE_ATTR, nonce) +} + +// Hi(str, salt, i): +pub(super) fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32]> { + let mut mac = Hmac::::new_varkey(s.as_bytes()) + .map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?; + + mac.update(&salt); + mac.update(&1u32.to_be_bytes()); + + let mut u = mac.finalize().into_bytes(); + let mut hi = u; + + for _ in 1..iter_count { + let mut mac = Hmac::::new_varkey(s.as_bytes()) + .map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?; + mac.update(u.as_slice()); + u = mac.finalize().into_bytes(); + hi = hi.iter().zip(u.iter()).map(|(&a, &b)| a ^ b).collect(); + } + + Ok(hi.into()) +} diff --git a/sqlx-postgres/src/error.rs b/sqlx-postgres/src/error.rs index 94395a14..709d4165 100644 --- a/sqlx-postgres/src/error.rs +++ b/sqlx-postgres/src/error.rs @@ -6,17 +6,35 @@ use sqlx_core::DatabaseError; /// An error returned from the PostgreSQL database server. #[allow(clippy::module_name_repetitions)] #[derive(Debug)] -pub struct PostgresDatabaseError(); +pub struct PostgresDatabaseError(String); + +impl PostgresDatabaseError { + pub(crate) fn protocol(msg: String) -> PostgresDatabaseError { + PostgresDatabaseError(msg) + } +} + +impl From for PostgresDatabaseError { + fn from(err: crypto_mac::InvalidKeyLength) -> Self { + PostgresDatabaseError::protocol(err.to_string()) + } +} + +impl From for PostgresDatabaseError { + fn from(err: crypto_mac::MacError) -> Self { + PostgresDatabaseError::protocol(err.to_string()) + } +} impl DatabaseError for PostgresDatabaseError { fn message(&self) -> &str { - todo!() + &self.0 } } impl Display for PostgresDatabaseError { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "TODO") + write!(f, "{}", self.0) } } diff --git a/sqlx-postgres/src/protocol.rs b/sqlx-postgres/src/protocol.rs index 5c37dfed..2195504e 100644 --- a/sqlx-postgres/src/protocol.rs +++ b/sqlx-postgres/src/protocol.rs @@ -12,6 +12,7 @@ mod notification; mod password; mod ready_for_query; mod response; +mod sasl; mod startup; mod terminate; @@ -22,6 +23,7 @@ pub(crate) use notification::Notification; pub(crate) use password::Password; pub(crate) use ready_for_query::ReadyForQuery; pub(crate) use response::{Notice, PgSeverity}; +pub(crate) use sasl::{SaslInitialResponse, SaslResponse}; pub(crate) use startup::Startup; pub(crate) use terminate::Terminate; diff --git a/sqlx-postgres/src/protocol/password.rs b/sqlx-postgres/src/protocol/password.rs index 665bbe43..15fcea2b 100644 --- a/sqlx-postgres/src/protocol/password.rs +++ b/sqlx-postgres/src/protocol/password.rs @@ -1,5 +1,6 @@ use std::fmt::Write; +use md5::{Digest, Md5}; use sqlx_core::io::Serialize; use sqlx_core::io::WriteExt; use sqlx_core::Result; @@ -40,21 +41,23 @@ impl Serialize<'_, ()> for Password<'_> { // Keep in mind the md5() function returns its result as a hex string. - let digest = md5::compute(password); - let digest = md5::compute(username); + let mut hasher = Md5::new(); - let mut outwrite = String::with_capacity(35); + hasher.update(password); + hasher.update(username); - let _ = write!(outwrite, "{:x}", digest); + let mut output = String::with_capacity(35); - let digest = md5::compute(&outwrite); - let digest = md5::compute(salt); + let _ = write!(output, "{:x}", hasher.finalize_reset()); - outwrite.clear(); + hasher.update(&output); + hasher.update(salt); - let _ = write!(outwrite, "md5{:x}", digest); + output.clear(); - buf.write_str_nul(&outwrite); + let _ = write!(output, "md5{:x}", hasher.finalize()); + + buf.write_str_nul(&output); } } }); diff --git a/sqlx-postgres/src/protocol/sasl.rs b/sqlx-postgres/src/protocol/sasl.rs new file mode 100644 index 00000000..0165ff6a --- /dev/null +++ b/sqlx-postgres/src/protocol/sasl.rs @@ -0,0 +1,40 @@ +use sqlx_core::io::Serialize; +use sqlx_core::io::WriteExt; +use sqlx_core::Result; + +use crate::io::PgBufMutExt; + +#[derive(Debug)] +pub struct SaslInitialResponse<'a> { + pub response: &'a str, + pub plus: bool, +} + +impl Serialize<'_, ()> for SaslInitialResponse<'_> { + fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { + buf.push(b'p'); + buf.write_length_prefixed(|buf| { + // name of the SASL authentication mechanism that the client selected + buf.write_str_nul(if self.plus { "SCRAM-SHA-256-PLUS" } else { "SCRAM-SHA-256" }); + + buf.extend(&(self.response.as_bytes().len() as i32).to_be_bytes()); + buf.extend(self.response.as_bytes()); + }); + + Ok(()) + } +} + +#[derive(Debug)] +pub struct SaslResponse<'a>(pub &'a str); + +impl Serialize<'_, ()> for SaslResponse<'_> { + fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { + buf.push(b'p'); + buf.write_length_prefixed(|buf| { + buf.extend(self.0.as_bytes()); + }); + + Ok(()) + } +}