feat: add sasl support

This commit is contained in:
Daniel Akhterov 2021-01-24 13:54:57 -08:00
parent c8f7601ad1
commit 80a1b19db9
No known key found for this signature in database
GPG Key ID: 80408CD2586A5A52
9 changed files with 373 additions and 27 deletions

46
Cargo.lock generated
View File

@ -382,6 +382,16 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "crypto-mac"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4857fd85a0c34b3c3297875b747c1e02e06b6a0ea32dd892d8192b9ce0813ea6"
dependencies = [
"generic-array",
"subtle",
]
[[package]]
name = "digest"
version = "0.9.0"
@ -551,6 +561,16 @@ dependencies = [
"libc",
]
[[package]]
name = "hmac"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1441c6b1e930e2817404b5046f1f989899143a12bf92de603b69f4e0aee1e15"
dependencies = [
"crypto-mac",
"digest",
]
[[package]]
name = "idna"
version = "0.2.0"
@ -641,10 +661,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08"
[[package]]
name = "md5"
version = "0.7.0"
name = "md-5"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
checksum = "7b5a279bb9607f9f53c22d496eade00d138d1bdcccd07d74650387cf94942a15"
dependencies = [
"block-buffer",
"digest",
"opaque-debug",
]
[[package]]
name = "memchr"
@ -1131,13 +1156,15 @@ dependencies = [
"byteorder",
"bytes",
"bytestring",
"crypto-mac",
"either",
"futures-executor",
"futures-io",
"futures-util",
"hmac",
"itoa",
"log",
"md5",
"md-5",
"memchr",
"percent-encoding",
"rand",
@ -1145,9 +1172,20 @@ dependencies = [
"sha-1",
"sha2",
"sqlx-core",
"stringprep",
"url",
]
[[package]]
name = "stringprep"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ee348cb74b87454fff4b551cbf727025810a004f88aeacae7f85b87f4e9a1c1"
dependencies = [
"unicode-bidi",
"unicode-normalization",
]
[[package]]
name = "subtle"
version = "2.4.0"

View File

@ -45,8 +45,11 @@ base64 = "0.13.0"
rand = "0.7"
itoa = "0.4.7"
atoi = "0.4.0"
byteorder = "1.4.2"
md5 = "0.7.0"
byteorder = { version = "1.4.2", default-features = false, features = [ "std" ] }
md-5 = { version = "0.9.1", default-features = false }
hmac = { version = "0.10.1", default-features = false }
stringprep = "0.1.2"
crypto-mac = "0.10.0"
[dev-dependencies]
sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core", features = ["_mock"] }

View File

@ -6,6 +6,9 @@ use sqlx_core::{Close, Connect, Connection, Runtime};
use crate::{Postgres, PostgresConnectOptions};
#[macro_use]
mod sasl;
mod close;
mod connect;
mod ping;

View File

@ -11,14 +11,17 @@
//!
//! https://dev.postgres.com/doc/internals/en/connection-phase.html
//!
use hmac::{Hmac, Mac, NewMac};
use sha2::{Digest, Sha256};
use sqlx_core::net::Stream as NetStream;
use sqlx_core::Error;
use sqlx_core::Result;
use crate::protocol::{
Authentication, BackendKeyData, Message, MessageType, Password, ReadyForQuery, Startup,
Authentication, BackendKeyData, Message, MessageType, Password, ReadyForQuery,
SaslInitialResponse, SaslResponse, Startup,
};
use crate::{PostgresConnectOptions, PostgresConnection};
use crate::{PostgresConnectOptions, PostgresConnection, PostgresDatabaseError};
macro_rules! connect {
(@blocking @tcp $options:ident) => {
@ -48,10 +51,7 @@ macro_rules! connect {
// To begin a session, a frontend opens a connection to the server
// and sends a startup message.
let mut params = vec![
// Sets the display format for date and time values,
// as well as the rules for interpreting ambiguous date input values.
("DateStyle", "ISO, MDY"),
let mut params = vec![ // Sets the display format for date and time values, as well as the rules for interpreting ambiguous date input values. ("DateStyle", "ISO, MDY"),
// Sets the client-side encoding (character set).
// <https://www.postgresql.org/docs/devel/multibyte.html#MULTIBYTE-CHARSET-SUPPORTED>
("client_encoding", "UTF8"),
@ -114,9 +114,9 @@ macro_rules! connect {
})?;
}
// Authentication::Sasl(body) => {
// sasl::authenticate(&mut stream, $options, body).await?;
// }
Authentication::Sasl(body) => {
sasl_authenticate!($(@$blocking)? self_, $options, body)
}
method => {
return Err(Error::configuration_msg(format!(

View File

@ -0,0 +1,239 @@
use hmac::{Hmac, Mac, NewMac};
use rand::Rng;
use sha2::digest::Digest;
use sha2::Sha256;
use sqlx_core::Error;
use sqlx_core::Result;
use crate::protocol::{
Authentication, AuthenticationSasl, MessageType, SaslInitialResponse, SaslResponse,
};
pub(super) const GS2_HEADER: &str = "n,,";
pub(super) const CHANNEL_ATTR: &str = "c";
pub(super) const USERNAME_ATTR: &str = "n";
pub(super) const CLIENT_PROOF_ATTR: &str = "p";
pub(super) const NONCE_ATTR: &str = "r";
macro_rules! sasl_authenticate {
(@blocking @packet $self:ident) => {
$self.read_packet()?;
};
(@packet $self:ident) => {
$self.read_packet_async().await?;
};
($(@$blocking:ident)? $self:ident, $options:ident, $data:ident) => {{
let mut has_sasl = false;
let mut has_sasl_plus = false;
let mut unknown = Vec::new();
for mechanism in $data.mechanisms() {
match mechanism {
"SCRAM-SHA-256" => {
has_sasl = true;
}
"SCRAM-SHA-256-PLUS" => {
has_sasl_plus = true;
}
_ => {
unknown.push(mechanism.to_owned());
}
}
}
if !has_sasl_plus && !has_sasl {
return Err(Error::connect(crate::PostgresDatabaseError::protocol(format!(
"unsupported SASL authentication mechanisms: {}",
unknown.join(", ")
))));
}
// channel-binding = "c=" base64
let channel_binding = format!("{}={}", crate::connection::sasl::CHANNEL_ATTR, base64::encode(crate::connection::sasl::GS2_HEADER));
// "n=" saslname ;; Usernames are prepared using SASLprep.
let username = format!("{}={}", crate::connection::sasl::USERNAME_ATTR, $options.get_username().unwrap_or_default());
let username = match stringprep::saslprep(&username) {
Ok(v) => v,
Err(err) => {
return Err(Error::connect(crate::PostgresDatabaseError::protocol(format!(
"failed to sasl prep the username: {:?}", err
))));
}
};
// nonce = "r=" c-nonce [s-nonce] ;; Second part provided by server.
let nonce = crate::connection::sasl::gen_nonce();
// client-first-message-bare = [reserved-mext ","] username "," nonce ["," extensions]
let client_first_message_bare =
format!("{username},{nonce}", username = username, nonce = nonce);
let client_first_message = format!(
"{gs2_header}{client_first_message_bare}",
gs2_header = crate::connection::sasl::GS2_HEADER,
client_first_message_bare = client_first_message_bare
);
$self.write_packet(&crate::protocol::SaslInitialResponse { response: &client_first_message, plus: false })?;
let message: Message = sasl_authenticate!($(@$blocking)? @packet $self);
let cont = match message.r#type {
MessageType::Authentication => {
match message.decode()? {
Authentication::SaslContinue(data) => data,
auth => {
return Err(Error::connect(PostgresDatabaseError::protocol(format!(
"expected SASLContinue but received {:?}", auth
))));
}
}
}
_ => {
todo!()
}
};
// SaltedPassword := Hi(Normalize(password), salt, i)
let salted_password =
crate::connection::sasl::hi($options.get_password().unwrap_or_default(), &cont.salt, cont.iterations)?;
// ClientKey := HMAC(SaltedPassword, "Client Key")
let mut mac = Hmac::<Sha256>::new_varkey(&salted_password)
.map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?;
mac.update(b"Client Key");
let client_key = mac.finalize().into_bytes();
// StoredKey := H(ClientKey)
let stored_key = Sha256::digest(&client_key);
// client-final-message-without-proof
let client_final_message_wo_proof = format!(
"{channel_binding},r={nonce}",
channel_binding = channel_binding,
nonce = &cont.nonce
);
// AuthMessage := client-first-message-bare + "," + server-first-message + "," + client-final-message-without-proof
let auth_message = format!(
"{client_first_message_bare},{server_first_message},{client_final_message_wo_proof}",
client_first_message_bare = client_first_message_bare,
server_first_message = cont.message,
client_final_message_wo_proof = client_final_message_wo_proof
);
// ClientSignature := HMAC(StoredKey, AuthMessage)
let mut mac = Hmac::<Sha256>::new_varkey(&stored_key)
.map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?;
mac.update(&auth_message.as_bytes());
let client_signature = mac.finalize().into_bytes();
// ClientProof := ClientKey XOR ClientSignature
let client_proof: Vec<u8> =
client_key.iter().zip(client_signature.iter()).map(|(&a, &b)| a ^ b).collect();
// ServerKey := HMAC(SaltedPassword, "Server Key")
let mut mac = Hmac::<Sha256>::new_varkey(&salted_password)
.map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?;
mac.update(b"Server Key");
let server_key = mac.finalize().into_bytes();
// ServerSignature := HMAC(ServerKey, AuthMessage)
let mut mac = Hmac::<Sha256>::new_varkey(&server_key)
.map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?;
mac.update(&auth_message.as_bytes());
// client-final-message = client-final-message-without-proof "," proof
let client_final_message = format!(
"{client_final_message_wo_proof},{client_proof_attr}={client_proof}",
client_final_message_wo_proof = client_final_message_wo_proof,
client_proof_attr = crate::connection::sasl::CLIENT_PROOF_ATTR,
client_proof = base64::encode(&client_proof)
);
$self.write_packet(&crate::protocol::SaslResponse(&client_final_message))?;
let message: Message = sasl_authenticate!($(@$blocking)? @packet $self);
let data = match message.r#type {
MessageType::Authentication => {
match message.decode()? {
Authentication::SaslFinal(data) => data,
auth => {
return Err(Error::connect(PostgresDatabaseError::protocol(format!(
"expected SASLContinue but received {:?}", auth
))));
}
}
}
r#type => {
return Err(Error::connect(PostgresDatabaseError::protocol(format!(
"Expected an authencation message type, found {:?}", r#type
))));
}
};
// authentication is only considered valid if this verification passes
mac.verify(&data.verifier)
.map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?;
}};
}
// nonce is a sequence of random printable bytes
pub(super) fn gen_nonce() -> String {
let mut rng = rand::thread_rng();
let count = rng.gen_range(64, 128);
// printable = %x21-2B / %x2D-7E
// ;; Printable ASCII except ",".
// ;; Note that any "printable" is also
// ;; a valid "value".
let nonce: String = std::iter::repeat(())
.map(|()| {
let mut c = rng.gen_range(0x21, 0x7F) as u8;
while c == 0x2C {
c = rng.gen_range(0x21, 0x7F) as u8;
}
c
})
.take(count)
.map(|c| c as char)
.collect();
rng.gen_range(32, 128);
format!("{}={}", crate::connection::sasl::NONCE_ATTR, nonce)
}
// Hi(str, salt, i):
pub(super) fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32]> {
let mut mac = Hmac::<Sha256>::new_varkey(s.as_bytes())
.map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?;
mac.update(&salt);
mac.update(&1u32.to_be_bytes());
let mut u = mac.finalize().into_bytes();
let mut hi = u;
for _ in 1..iter_count {
let mut mac = Hmac::<Sha256>::new_varkey(s.as_bytes())
.map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?;
mac.update(u.as_slice());
u = mac.finalize().into_bytes();
hi = hi.iter().zip(u.iter()).map(|(&a, &b)| a ^ b).collect();
}
Ok(hi.into())
}

View File

@ -6,17 +6,35 @@ use sqlx_core::DatabaseError;
/// An error returned from the PostgreSQL database server.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct PostgresDatabaseError();
pub struct PostgresDatabaseError(String);
impl PostgresDatabaseError {
pub(crate) fn protocol(msg: String) -> PostgresDatabaseError {
PostgresDatabaseError(msg)
}
}
impl From<crypto_mac::InvalidKeyLength> for PostgresDatabaseError {
fn from(err: crypto_mac::InvalidKeyLength) -> Self {
PostgresDatabaseError::protocol(err.to_string())
}
}
impl From<crypto_mac::MacError> for PostgresDatabaseError {
fn from(err: crypto_mac::MacError) -> Self {
PostgresDatabaseError::protocol(err.to_string())
}
}
impl DatabaseError for PostgresDatabaseError {
fn message(&self) -> &str {
todo!()
&self.0
}
}
impl Display for PostgresDatabaseError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "TODO")
write!(f, "{}", self.0)
}
}

View File

@ -12,6 +12,7 @@ mod notification;
mod password;
mod ready_for_query;
mod response;
mod sasl;
mod startup;
mod terminate;
@ -22,6 +23,7 @@ pub(crate) use notification::Notification;
pub(crate) use password::Password;
pub(crate) use ready_for_query::ReadyForQuery;
pub(crate) use response::{Notice, PgSeverity};
pub(crate) use sasl::{SaslInitialResponse, SaslResponse};
pub(crate) use startup::Startup;
pub(crate) use terminate::Terminate;

View File

@ -1,5 +1,6 @@
use std::fmt::Write;
use md5::{Digest, Md5};
use sqlx_core::io::Serialize;
use sqlx_core::io::WriteExt;
use sqlx_core::Result;
@ -40,21 +41,23 @@ impl Serialize<'_, ()> for Password<'_> {
// Keep in mind the md5() function returns its result as a hex string.
let digest = md5::compute(password);
let digest = md5::compute(username);
let mut hasher = Md5::new();
let mut outwrite = String::with_capacity(35);
hasher.update(password);
hasher.update(username);
let _ = write!(outwrite, "{:x}", digest);
let mut output = String::with_capacity(35);
let digest = md5::compute(&outwrite);
let digest = md5::compute(salt);
let _ = write!(output, "{:x}", hasher.finalize_reset());
outwrite.clear();
hasher.update(&output);
hasher.update(salt);
let _ = write!(outwrite, "md5{:x}", digest);
output.clear();
buf.write_str_nul(&outwrite);
let _ = write!(output, "md5{:x}", hasher.finalize());
buf.write_str_nul(&output);
}
}
});

View File

@ -0,0 +1,40 @@
use sqlx_core::io::Serialize;
use sqlx_core::io::WriteExt;
use sqlx_core::Result;
use crate::io::PgBufMutExt;
#[derive(Debug)]
pub struct SaslInitialResponse<'a> {
pub response: &'a str,
pub plus: bool,
}
impl Serialize<'_, ()> for SaslInitialResponse<'_> {
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
buf.push(b'p');
buf.write_length_prefixed(|buf| {
// name of the SASL authentication mechanism that the client selected
buf.write_str_nul(if self.plus { "SCRAM-SHA-256-PLUS" } else { "SCRAM-SHA-256" });
buf.extend(&(self.response.as_bytes().len() as i32).to_be_bytes());
buf.extend(self.response.as_bytes());
});
Ok(())
}
}
#[derive(Debug)]
pub struct SaslResponse<'a>(pub &'a str);
impl Serialize<'_, ()> for SaslResponse<'_> {
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
buf.push(b'p');
buf.write_length_prefixed(|buf| {
buf.extend(self.0.as_bytes());
});
Ok(())
}
}