mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-09-29 14:05:51 +00:00
Resolve Comments
- Remove `hex` from root `Cargo.toml` - Make `hmac` crate optional - Clean up checking mechanisms for "SCRAM-SHA-256" - Use `str::from_utf8` instead of `String::from_utf8_lossyf - Update `Sasl*Response` structs be tuple structs - Factor out `len` in `SaslInitialResponse.encode()` - Use `protocol_err` instead of `expect` when constructing `Hmacf instances - Remove `it_connects_to_database_user` test as it was too fragile - Move `sasl_auth` function into `postgres/connection` as it more related to `Connection` rather than `protocl` - Return an error when decoding base64 salt rather than panicing in `Authentication::SaslContinue`
This commit is contained in:
parent
507d988fc4
commit
db230e2ce0
14
Cargo.lock
generated
14
Cargo.lock
generated
@ -531,6 +531,14 @@ dependencies = [
|
||||
"typenum 1.11.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "generic-array"
|
||||
version = "0.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"typenum 1.11.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.1.14"
|
||||
@ -1209,8 +1217,7 @@ dependencies = [
|
||||
"futures-channel 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"futures-core 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"futures-util 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"generic-array 0.12.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"hex 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"generic-array 0.13.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"hmac 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"matches 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
@ -1638,7 +1645,8 @@ dependencies = [
|
||||
"checksum futures-util 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "c0d66274fb76985d3c62c886d1da7ac4c0903a8c9f754e8fe0f35a6a6cc39e76"
|
||||
"checksum futures-util-preview 0.3.0-alpha.19 (registry+https://github.com/rust-lang/crates.io-index)" = "5ce968633c17e5f97936bd2797b6e38fb56cf16a7422319f7ec2e30d3c470e8d"
|
||||
"checksum generic-array 0.12.3 (registry+https://github.com/rust-lang/crates.io-index)" = "c68f0274ae0e023facc3c97b2e00f076be70e254bc851d972503b328db79b2ec"
|
||||
"checksum getrandom 0.1.14 (registry+https://github.com/rust-lang/crates.io-index)" = "7abc8dd8451921606d809ba32e95b6111925cd2906060d2dcc29c070220503eb"
|
||||
"checksum generic-array 0.13.2 (registry+https://github.com/rust-lang/crates.io-index)" = "0ed1e761351b56f54eb9dcd0cfaca9fd0daecf93918e1cfc01c8a3d26ee7adcd"
|
||||
"checksum getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "e7db7ca94ed4cd01190ceee0d8a8052f08a247aa1b469a7f68c6a3b71afcf407"
|
||||
"checksum h2 0.1.26 (registry+https://github.com/rust-lang/crates.io-index)" = "a5b34c246847f938a410a03c5458c7fee2274436675e76d8b903c08efc29c462"
|
||||
"checksum hermit-abi 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "f629dc602392d3ec14bfc8a09b5e644d7ffd725102b48b81e59f90f2633621d7"
|
||||
"checksum hex 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "023b39be39e3a2da62a94feb433e91e8bcd37676fbc8bea371daf52b7a769a3e"
|
||||
|
@ -18,7 +18,7 @@ all-features = true
|
||||
[features]
|
||||
default = []
|
||||
unstable = []
|
||||
postgres = [ "md-5", "sha2", "base64", "sha-1", "rand" ]
|
||||
postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac" ]
|
||||
mysql = [ "sha-1", "sha2", "generic-array", "num-bigint", "base64", "digest", "rand" ]
|
||||
|
||||
[dependencies]
|
||||
@ -34,7 +34,7 @@ digest = { version = "0.8.1", default-features = false, optional = true, feature
|
||||
futures-channel = { version = "0.3.1", default-features = false }
|
||||
futures-core = { version = "0.3.1", default-features = false }
|
||||
futures-util = { version = "0.3.1", default-features = false }
|
||||
generic-array = { version = "0.12.3", default-features = false, optional = true }
|
||||
generic-array = { version = "0.13.2", default-features = false, optional = true }
|
||||
log = { version = "0.4.8", default-features = false }
|
||||
md-5 = { version = "0.8.0", default-features = false, optional = true }
|
||||
memchr = { version = "2.2.1", default-features = false }
|
||||
@ -44,8 +44,7 @@ sha-1 = { version = "0.8.1", default-features = false, optional = true }
|
||||
sha2 = { version = "0.8.0", default-features = false, optional = true }
|
||||
url = { version = "2.1.0", default-features = false }
|
||||
uuid = { version = "0.8.1", default-features = false, optional = true }
|
||||
hex = "0.4.0"
|
||||
hmac = "0.7.1"
|
||||
hmac = { version = "0.7.1", default-features = false, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
matches = "0.1.8"
|
||||
|
@ -7,10 +7,13 @@ use futures_core::future::BoxFuture;
|
||||
use crate::cache::StatementCache;
|
||||
use crate::connection::Connection;
|
||||
use crate::io::{Buf, BufStream};
|
||||
use crate::postgres::protocol::{self, Decode, Encode, Message, StatementId};
|
||||
use crate::postgres::protocol::{self, Decode, Encode, Message, StatementId, SaslResponse, SaslInitialResponse, hi, Authentication};
|
||||
use crate::postgres::PgError;
|
||||
use crate::url::Url;
|
||||
use std::ops::Deref;
|
||||
use sha2::{Sha256, Digest};
|
||||
use hmac::{Mac, Hmac};
|
||||
use crate::Result;
|
||||
use rand::Rng;
|
||||
|
||||
/// An asynchronous connection to a [Postgres] database.
|
||||
///
|
||||
@ -38,7 +41,7 @@ pub struct PgConnection {
|
||||
|
||||
impl PgConnection {
|
||||
// https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3
|
||||
async fn startup(&mut self, url: Url) -> crate::Result<()> {
|
||||
async fn startup(&mut self, url: Url) -> Result<()> {
|
||||
// Defaults to postgres@.../postgres
|
||||
let username = url.username().unwrap_or("postgres");
|
||||
let database = url.database().unwrap_or("postgres");
|
||||
@ -94,26 +97,21 @@ impl PgConnection {
|
||||
}
|
||||
|
||||
protocol::Authentication::Sasl { mechanisms } => {
|
||||
let mechanism = (*mechanisms)
|
||||
.get(0)
|
||||
.ok_or(protocol_err!(
|
||||
match mechanisms.get(0).map(|m| &**m) {
|
||||
Some("SCRAM-SHA-256") => {
|
||||
sasl_auth(
|
||||
self,
|
||||
username,
|
||||
url.password().unwrap_or_default(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
_ => return Err(protocol_err!(
|
||||
"Expected mechanisms SCRAM-SHA-256, but received {:?}",
|
||||
mechanisms
|
||||
))?
|
||||
.deref();
|
||||
if "SCRAM-SHA-256" == &*mechanism {
|
||||
protocol::sasl_auth(
|
||||
self,
|
||||
username,
|
||||
url.password().unwrap_or_default(),
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
Err(protocol_err!(
|
||||
"Expected mechanisms SCRAM-SHA-256, but received {:?}",
|
||||
mechanisms
|
||||
))?
|
||||
}?;
|
||||
).into()),
|
||||
}
|
||||
}
|
||||
|
||||
auth => {
|
||||
@ -146,7 +144,7 @@ impl PgConnection {
|
||||
}
|
||||
|
||||
// https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.10
|
||||
async fn terminate(mut self) -> crate::Result<()> {
|
||||
async fn terminate(mut self) -> Result<()> {
|
||||
protocol::Terminate.encode(self.stream.buffer_mut());
|
||||
|
||||
self.stream.flush().await?;
|
||||
@ -156,7 +154,7 @@ impl PgConnection {
|
||||
}
|
||||
|
||||
// Wait and return the next message to be received from Postgres.
|
||||
pub(super) async fn receive(&mut self) -> crate::Result<Option<Message>> {
|
||||
pub(super) async fn receive(&mut self) -> Result<Option<Message>> {
|
||||
loop {
|
||||
// Read the message header (id + len)
|
||||
let mut header = ret_if_none!(self.stream.peek(5).await?);
|
||||
@ -222,7 +220,7 @@ impl PgConnection {
|
||||
}
|
||||
|
||||
impl PgConnection {
|
||||
pub(super) async fn open(url: crate::Result<Url>) -> crate::Result<Self> {
|
||||
pub(super) async fn open(url: Result<Url>) -> Result<Self> {
|
||||
let url = url?;
|
||||
let stream = TcpStream::connect((url.host(), url.port(5432))).await?;
|
||||
let mut self_ = Self {
|
||||
@ -242,7 +240,7 @@ impl PgConnection {
|
||||
}
|
||||
|
||||
impl Connection for PgConnection {
|
||||
fn open<T>(url: T) -> BoxFuture<'static, crate::Result<Self>>
|
||||
fn open<T>(url: T) -> BoxFuture<'static, Result<Self>>
|
||||
where
|
||||
T: TryInto<Url, Error = crate::Error>,
|
||||
Self: Sized,
|
||||
@ -250,7 +248,153 @@ impl Connection for PgConnection {
|
||||
Box::pin(PgConnection::open(url.try_into()))
|
||||
}
|
||||
|
||||
fn close(self) -> BoxFuture<'static, crate::Result<()>> {
|
||||
fn close(self) -> BoxFuture<'static, Result<()>> {
|
||||
Box::pin(self.terminate())
|
||||
}
|
||||
}
|
||||
|
||||
static GS2_HEADER: &'static str = "n,,";
|
||||
static CHANNEL_ATTR: &'static str = "c";
|
||||
static USERNAME_ATTR: &'static str = "n";
|
||||
static CLIENT_PROOF_ATTR: &'static str = "p";
|
||||
static NONCE_ATTR: &'static str = "r";
|
||||
|
||||
// Nonce generator
|
||||
// Nonce is a sequence of random printable bytes
|
||||
fn nonce() -> String {
|
||||
let mut rng = rand::thread_rng();
|
||||
let count = rng.gen_range(64, 128);
|
||||
// printable = %x21-2B / %x2D-7E
|
||||
// ;; Printable ASCII except ",".
|
||||
// ;; Note that any "printable" is also
|
||||
// ;; a valid "value".
|
||||
let nonce: String = std::iter::repeat(())
|
||||
.map(|()| {
|
||||
let mut c = rng.gen_range(0x21, 0x7F) as u8;
|
||||
|
||||
while c == 0x2C {
|
||||
c = rng.gen_range(0x21, 0x7F) as u8;
|
||||
}
|
||||
|
||||
c
|
||||
})
|
||||
.take(count)
|
||||
.map(|c| c as char)
|
||||
.collect();
|
||||
|
||||
rng.gen_range(32, 128);
|
||||
format!("{}={}", NONCE_ATTR, nonce)
|
||||
}
|
||||
|
||||
// Performs authenticiton using Simple Authentication Security Layer (SASL) which is what
|
||||
// Postgres uses
|
||||
async fn sasl_auth<T: AsRef<str>>(
|
||||
conn: &mut PgConnection,
|
||||
username: T,
|
||||
password: T,
|
||||
) -> Result<()> {
|
||||
// channel-binding = "c=" base64
|
||||
let channel_binding = format!("{}={}", CHANNEL_ATTR, base64::encode(GS2_HEADER));
|
||||
// "n=" saslname ;; Usernames are prepared using SASLprep.
|
||||
let username = format!("{}={}", USERNAME_ATTR, username.as_ref());
|
||||
// nonce = "r=" c-nonce [s-nonce] ;; Second part provided by server.
|
||||
let nonce = nonce();
|
||||
let client_first_message_bare =
|
||||
format!("{username},{nonce}", username = username, nonce = nonce);
|
||||
// client-first-message-bare = [reserved-mext ","] username "," nonce ["," extensions]
|
||||
let client_first_message = format!(
|
||||
"{gs2_header}{client_first_message_bare}",
|
||||
gs2_header = GS2_HEADER,
|
||||
client_first_message_bare = client_first_message_bare
|
||||
);
|
||||
|
||||
SaslInitialResponse(&client_first_message)
|
||||
.encode(conn.stream.buffer_mut());
|
||||
conn.stream.flush().await?;
|
||||
|
||||
let server_first_message = conn.receive().await?;
|
||||
|
||||
if let Some(Message::Authentication(auth)) = server_first_message {
|
||||
if let Authentication::SaslContinue(sasl) = *auth {
|
||||
let server_first_message = sasl.data;
|
||||
|
||||
// SaltedPassword := Hi(Normalize(password), salt, i)
|
||||
let salted_password = hi(password.as_ref(), &sasl.salt, sasl.iter_count)?;
|
||||
|
||||
// ClientKey := HMAC(SaltedPassword, "Client Key")
|
||||
let mut mac = Hmac::<Sha256>::new_varkey(&salted_password)
|
||||
.map_err(|_| protocol_err!("HMAC can take key of any size"))?;
|
||||
mac.input(b"Client Key");
|
||||
let client_key = mac.result().code();
|
||||
|
||||
// StoredKey := H(ClientKey)
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.input(client_key);
|
||||
let stored_key = hasher.result();
|
||||
|
||||
// String::from_utf8_lossy should never fail because Postgres requires
|
||||
// the nonce to be all printable characters except ','
|
||||
let client_final_message_wo_proof = format!(
|
||||
"{channel_binding},r={nonce}",
|
||||
channel_binding = channel_binding,
|
||||
nonce = String::from_utf8_lossy(&sasl.nonce)
|
||||
);
|
||||
|
||||
// AuthMessage := client-first-message-bare + "," + server-first-message + "," + client-final-message-without-proof
|
||||
let auth_message = format!("{client_first_message_bare},{server_first_message},{client_final_message_wo_proof}",
|
||||
client_first_message_bare = client_first_message_bare,
|
||||
server_first_message = server_first_message,
|
||||
client_final_message_wo_proof = client_final_message_wo_proof);
|
||||
|
||||
// ClientSignature := HMAC(StoredKey, AuthMessage)
|
||||
let mut mac =
|
||||
Hmac::<Sha256>::new_varkey(&stored_key).expect("HMAC can take key of any size");
|
||||
mac.input(&auth_message.as_bytes());
|
||||
let client_signature = mac.result().code();
|
||||
|
||||
// ClientProof := ClientKey XOR ClientSignature
|
||||
let client_proof: Vec<u8> = client_key
|
||||
.iter()
|
||||
.zip(client_signature.iter())
|
||||
.map(|(&a, &b)| a ^ b)
|
||||
.collect();
|
||||
|
||||
// ServerKey := HMAC(SaltedPassword, "Server Key")
|
||||
let mut mac = Hmac::<Sha256>::new_varkey(&salted_password)
|
||||
.map_err(|_| protocol_err!("HMAC can take key of any size"))?;
|
||||
mac.input(b"Server Key");
|
||||
let server_key = mac.result().code();
|
||||
|
||||
// ServerSignature := HMAC(ServerKey, AuthMessage)
|
||||
let mut mac =
|
||||
Hmac::<Sha256>::new_varkey(&server_key).expect("HMAC can take key of any size");
|
||||
mac.input(&auth_message.as_bytes());
|
||||
let _server_signature = mac.result().code();
|
||||
|
||||
// client-final-message = client-final-message-without-proof "," proof
|
||||
let client_final_message = format!(
|
||||
"{client_final_message_wo_proof},{client_proof_attr}={client_proof}",
|
||||
client_final_message_wo_proof = client_final_message_wo_proof,
|
||||
client_proof_attr = CLIENT_PROOF_ATTR,
|
||||
client_proof = base64::encode(&client_proof)
|
||||
);
|
||||
|
||||
SaslResponse(&client_final_message)
|
||||
.encode(conn.stream.buffer_mut());
|
||||
conn.stream.flush().await?;
|
||||
let _server_final_response = conn.receive().await?;
|
||||
|
||||
Ok(())
|
||||
} else {
|
||||
Err(protocol_err!(
|
||||
"Expected Authentication::SaslContinue, but received {:?}",
|
||||
auth
|
||||
))?
|
||||
}
|
||||
} else {
|
||||
Err(protocol_err!(
|
||||
"Expected Message::Authentication, but received {:?}",
|
||||
server_first_message
|
||||
))?
|
||||
}
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ use crate::postgres::protocol::Decode;
|
||||
use byteorder::NetworkEndian;
|
||||
use std::borrow::Cow;
|
||||
use std::io;
|
||||
use std::str;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Authentication {
|
||||
@ -99,29 +100,44 @@ impl Decode for Authentication {
|
||||
let mut nonce: Vec<u8> = Vec::new();
|
||||
let mut iter_count: u32 = 0;
|
||||
|
||||
buf.split(|byte| *byte == b',')
|
||||
let key_value: Vec<(char, &[u8])> = buf
|
||||
.split(|byte| *byte == b',')
|
||||
.map(|s| {
|
||||
let (key, value) = s.split_at(1);
|
||||
let value = value.split_at(1).1;
|
||||
|
||||
(key[0] as char, value)
|
||||
})
|
||||
.for_each(|(key, value)| match key {
|
||||
.collect();
|
||||
|
||||
for (key, value) in key_value.iter() {
|
||||
match key {
|
||||
's' => salt = value.to_vec(),
|
||||
'r' => nonce = value.to_vec(),
|
||||
'i' => {
|
||||
iter_count = u32::from_str_radix(&String::from_utf8_lossy(&value), 10)
|
||||
.unwrap_or(0);
|
||||
let s = str::from_utf8(&value).map_err(|_| {
|
||||
protocol_err!(
|
||||
"iteration count in sasl response was not a valid utf8 string"
|
||||
)
|
||||
})?;
|
||||
iter_count = u32::from_str_radix(&s, 10).unwrap_or(0);
|
||||
}
|
||||
|
||||
_ => {}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Authentication::SaslContinue(SaslContinue {
|
||||
salt: base64::decode(&salt).unwrap(),
|
||||
salt: base64::decode(&salt).map_err(|_| {
|
||||
protocol_err!("salt value response from postgres was not base64 encoded")
|
||||
})?,
|
||||
nonce,
|
||||
iter_count,
|
||||
data: String::from_utf8_lossy(buf).into_owned(),
|
||||
data: str::from_utf8(buf)
|
||||
.map_err(|_| {
|
||||
protocol_err!("SaslContinue response was not a valid utf8 string")
|
||||
})?
|
||||
.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -31,7 +31,7 @@ pub use flush::Flush;
|
||||
pub use parse::Parse;
|
||||
pub use password_message::PasswordMessage;
|
||||
pub use query::Query;
|
||||
pub use sasl::{sasl_auth, SaslInitialResponse, SaslResponse};
|
||||
pub use sasl::{hi, SaslInitialResponse, SaslResponse};
|
||||
pub use startup_message::StartupMessage;
|
||||
pub use statement::StatementId;
|
||||
pub use sync::Sync;
|
||||
|
@ -6,186 +6,35 @@ use crate::postgres::protocol::Message;
|
||||
use crate::Result;
|
||||
use byteorder::NetworkEndian;
|
||||
use hmac::{Hmac, Mac};
|
||||
use rand::Rng;
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
pub struct SaslInitialResponse {
|
||||
// pub username: String,
|
||||
// pub passord: String,
|
||||
pub s: String,
|
||||
}
|
||||
pub struct SaslInitialResponse<'a>(pub &'a str);
|
||||
|
||||
impl Encode for SaslInitialResponse {
|
||||
impl<'a> Encode for SaslInitialResponse<'a> {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
let len = self.0.as_bytes().len() as u32;
|
||||
buf.push(b'p');
|
||||
buf.put_u32::<NetworkEndian>(4u32 + self.s.as_str().as_bytes().len() as u32 + 14u32 + 4u32);
|
||||
buf.put_u32::<NetworkEndian>(4u32 + len + 14u32 + 4u32);
|
||||
buf.put_str_nul("SCRAM-SHA-256");
|
||||
buf.put_u32::<NetworkEndian>(self.s.as_str().as_bytes().len() as u32);
|
||||
buf.extend_from_slice(self.s.as_str().as_bytes());
|
||||
buf.put_u32::<NetworkEndian>(len);
|
||||
buf.extend_from_slice(self.0.as_bytes());
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SaslResponse {
|
||||
pub s: String,
|
||||
}
|
||||
pub struct SaslResponse<'a>(pub &'a str);
|
||||
|
||||
impl Encode for SaslResponse {
|
||||
impl<'a> Encode for SaslResponse<'a> {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
buf.push(b'p');
|
||||
buf.put_u32::<NetworkEndian>(4u32 + self.s.as_str().as_bytes().len() as u32);
|
||||
buf.extend_from_slice(self.s.as_str().as_bytes());
|
||||
}
|
||||
}
|
||||
|
||||
static GS2_HEADER: &'static str = "n,,";
|
||||
static CHANNEL_ATTR: &'static str = "c";
|
||||
static USERNAME_ATTR: &'static str = "n";
|
||||
static CLIENT_PROOF_ATTR: &'static str = "p";
|
||||
static NONCE_ATTR: &'static str = "r";
|
||||
|
||||
pub fn nonce() -> String {
|
||||
let mut rng = rand::thread_rng();
|
||||
let count = rng.gen_range(64, 128);
|
||||
// printable = %x21-2B / %x2D-7E
|
||||
// ;; Printable ASCII except ",".
|
||||
// ;; Note that any "printable" is also
|
||||
// ;; a valid "value".
|
||||
let nonce: String = std::iter::repeat(())
|
||||
.map(|()| {
|
||||
let mut c = rng.gen_range(0x21, 0x7F) as u8;
|
||||
|
||||
while c == 0x2C {
|
||||
c = rng.gen_range(0x21, 0x7F) as u8;
|
||||
}
|
||||
|
||||
c
|
||||
})
|
||||
.take(count)
|
||||
.map(|c| c as char)
|
||||
.collect();
|
||||
|
||||
rng.gen_range(32, 128);
|
||||
format!("{}={}", NONCE_ATTR, nonce)
|
||||
}
|
||||
|
||||
pub async fn sasl_auth<T: AsRef<str>>(
|
||||
conn: &mut PgConnection,
|
||||
username: T,
|
||||
password: T,
|
||||
) -> Result<()> {
|
||||
// channel-binding = "c=" base64
|
||||
let channel_binding = format!("{}={}", CHANNEL_ATTR, base64::encode(GS2_HEADER));
|
||||
// "n=" saslname ;; Usernames are prepared using SASLprep.
|
||||
let username = format!("{}={}", USERNAME_ATTR, username.as_ref());
|
||||
// nonce = "r=" c-nonce [s-nonce] ;; Second part provided by server.
|
||||
let nonce = nonce();
|
||||
let client_first_message_bare =
|
||||
format!("{username},{nonce}", username = username, nonce = nonce);
|
||||
// client-first-message-bare = [reserved-mext ","] username "," nonce ["," extensions]
|
||||
let client_first_message = format!(
|
||||
"{gs2_header}{client_first_message_bare}",
|
||||
gs2_header = GS2_HEADER,
|
||||
client_first_message_bare = client_first_message_bare
|
||||
);
|
||||
|
||||
SaslInitialResponse {
|
||||
s: client_first_message,
|
||||
}
|
||||
.encode(conn.stream.buffer_mut());
|
||||
conn.stream.flush().await?;
|
||||
|
||||
let server_first_message = conn.receive().await?;
|
||||
|
||||
if let Some(Message::Authentication(auth)) = server_first_message {
|
||||
if let SaslContinue(sasl) = *auth {
|
||||
let server_first_message = sasl.data;
|
||||
|
||||
// SaltedPassword := Hi(Normalize(password), salt, i)
|
||||
let salted_password = hi(password.as_ref(), sasl.salt, sasl.iter_count);
|
||||
|
||||
// ClientKey := HMAC(SaltedPassword, "Client Key")
|
||||
let mut mac = Hmac::<Sha256>::new_varkey(&salted_password)
|
||||
.expect("HMAC can take key of any size");
|
||||
mac.input(b"Client Key");
|
||||
let client_key = mac.result().code();
|
||||
|
||||
// StoredKey := H(ClientKey)
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.input(client_key);
|
||||
let stored_key = hasher.result();
|
||||
|
||||
// String::from_utf8_lossy should never fail because Postgres requires
|
||||
// the nonce to be all printable characters except ','
|
||||
let client_final_message_wo_proof = format!(
|
||||
"{channel_binding},r={nonce}",
|
||||
channel_binding = channel_binding,
|
||||
nonce = String::from_utf8_lossy(&sasl.nonce)
|
||||
);
|
||||
|
||||
// AuthMessage := client-first-message-bare + "," + server-first-message + "," + client-final-message-without-proof
|
||||
let auth_message = format!("{client_first_message_bare},{server_first_message},{client_final_message_wo_proof}",
|
||||
client_first_message_bare = client_first_message_bare,
|
||||
server_first_message = server_first_message,
|
||||
client_final_message_wo_proof = client_final_message_wo_proof);
|
||||
|
||||
// ClientSignature := HMAC(StoredKey, AuthMessage)
|
||||
let mut mac =
|
||||
Hmac::<Sha256>::new_varkey(&stored_key).expect("HMAC can take key of any size");
|
||||
mac.input(&auth_message.as_bytes());
|
||||
let client_signature = mac.result().code();
|
||||
|
||||
// ClientProof := ClientKey XOR ClientSignature
|
||||
let client_proof: Vec<u8> = client_key
|
||||
.iter()
|
||||
.zip(client_signature.iter())
|
||||
.map(|(&a, &b)| a ^ b)
|
||||
.collect();
|
||||
|
||||
// ServerKey := HMAC(SaltedPassword, "Server Key")
|
||||
let mut mac = Hmac::<Sha256>::new_varkey(&salted_password)
|
||||
.expect("HMAC can take key of any size");
|
||||
mac.input(b"Server Key");
|
||||
let server_key = mac.result().code();
|
||||
|
||||
// ServerSignature := HMAC(ServerKey, AuthMessage)
|
||||
let mut mac =
|
||||
Hmac::<Sha256>::new_varkey(&server_key).expect("HMAC can take key of any size");
|
||||
mac.input(&auth_message.as_bytes());
|
||||
let server_signature = mac.result().code();
|
||||
|
||||
// client-final-message = client-final-message-without-proof "," proof
|
||||
let client_final_message = format!(
|
||||
"{client_final_message_wo_proof},p={client_proof}",
|
||||
client_final_message_wo_proof = client_final_message_wo_proof,
|
||||
client_proof = base64::encode(&client_proof)
|
||||
);
|
||||
|
||||
SaslResponse {
|
||||
s: client_final_message,
|
||||
}
|
||||
.encode(conn.stream.buffer_mut());
|
||||
conn.stream.flush().await?;
|
||||
let server_final_response = conn.receive().await?;
|
||||
|
||||
Ok(())
|
||||
} else {
|
||||
Err(protocol_err!(
|
||||
"Expected Authentication::SaslContinue, but received {:?}",
|
||||
auth
|
||||
))?
|
||||
}
|
||||
} else {
|
||||
Err(protocol_err!(
|
||||
"Expected Message::Authentication, but received {:?}",
|
||||
server_first_message
|
||||
))?
|
||||
buf.put_u32::<NetworkEndian>(4u32 + self.0.as_bytes().len() as u32);
|
||||
buf.extend_from_slice(self.0.as_bytes());
|
||||
}
|
||||
}
|
||||
|
||||
// Hi(str, salt, i):
|
||||
pub fn hi<T: AsRef<str>>(s: T, salt: Vec<u8>, iter_count: u32) -> Vec<u8> {
|
||||
let mut mac =
|
||||
Hmac::<Sha256>::new_varkey(s.as_ref().as_bytes()).expect("HMAC can take key of any size");
|
||||
pub fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32]> {
|
||||
let mut mac = Hmac::<Sha256>::new_varkey(s.as_bytes())
|
||||
.map_err(|_| protocol_err!("HMAC can take key of any size"))?;
|
||||
|
||||
mac.input(&salt);
|
||||
mac.input(&1u32.to_be_bytes());
|
||||
@ -194,12 +43,12 @@ pub fn hi<T: AsRef<str>>(s: T, salt: Vec<u8>, iter_count: u32) -> Vec<u8> {
|
||||
let mut hi = u;
|
||||
|
||||
for _ in 1..iter_count {
|
||||
let mut mac = Hmac::<Sha256>::new_varkey(s.as_ref().as_bytes())
|
||||
.expect("HMAC can take key of any size");
|
||||
let mut mac = Hmac::<Sha256>::new_varkey(s.as_bytes())
|
||||
.map_err(|_| protocol_err!(" HMAC can take key of any size"))?;
|
||||
mac.input(u.as_slice());
|
||||
u = mac.result().code();
|
||||
hi = hi.iter().zip(u.iter()).map(|(&a, &b)| a ^ b).collect();
|
||||
}
|
||||
|
||||
hi.to_vec()
|
||||
Ok(hi.into())
|
||||
}
|
||||
|
@ -14,30 +14,6 @@ async fn it_connects() -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// #[async_std::test]
|
||||
// async fn it_connects_to_database_user() -> anyhow::Result<()> {
|
||||
// let mut conn = connect().await?;
|
||||
|
||||
// let row = sqlx::query("select current_database()")
|
||||
// .fetch_one(&mut conn)
|
||||
// .await?;
|
||||
|
||||
// let current_db: String = row.get(0);
|
||||
|
||||
// let row = sqlx::query("select current_user")
|
||||
// .fetch_one(&mut conn)
|
||||
// .await?;
|
||||
|
||||
// let current_user: String = row.get(0);
|
||||
|
||||
// assert_eq!(current_db, "postgres");
|
||||
// assert_eq!(current_user, "postgres");
|
||||
|
||||
// conn.close().await?;
|
||||
|
||||
// Ok(())
|
||||
// }
|
||||
|
||||
#[async_std::test]
|
||||
async fn it_executes() -> anyhow::Result<()> {
|
||||
let mut conn = connect().await?;
|
||||
|
Loading…
x
Reference in New Issue
Block a user