refactor(postgres): baseline postgres driver against the now near-complete state of the mysql driver

This commit is contained in:
Ryan Leckey 2021-03-06 13:59:24 -08:00
parent 39e2658537
commit baa63d33e1
No known key found for this signature in database
GPG Key ID: F8AA68C235AB08C9
46 changed files with 721 additions and 2335 deletions

76
Cargo.lock generated
View File

@ -379,16 +379,6 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "crypto-mac"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4857fd85a0c34b3c3297875b747c1e02e06b6a0ea32dd892d8192b9ce0813ea6"
dependencies = [
"generic-array",
"subtle",
]
[[package]]
name = "ctor"
version = "0.1.19"
@ -565,16 +555,6 @@ dependencies = [
"libc",
]
[[package]]
name = "hmac"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1441c6b1e930e2817404b5046f1f989899143a12bf92de603b69f4e0aee1e15"
dependencies = [
"crypto-mac",
"digest",
]
[[package]]
name = "idna"
version = "0.2.2"
@ -595,12 +575,6 @@ dependencies = [
"cfg-if 1.0.0",
]
[[package]]
name = "itoa"
version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736"
[[package]]
name = "js-sys"
version = "0.3.48"
@ -630,9 +604,9 @@ dependencies = [
[[package]]
name = "libc"
version = "0.2.87"
version = "0.2.88"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "265d751d31d6780a3f956bb5b8022feba2d94eeee5a84ba64f4212eedca42213"
checksum = "03b07a082330a35e43f63177cc01689da34fbffa0105e1246cf0311472cac73a"
[[package]]
name = "libm"
@ -665,17 +639,6 @@ version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08"
[[package]]
name = "md-5"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b5a279bb9607f9f53c22d496eade00d138d1bdcccd07d74650387cf94942a15"
dependencies = [
"block-buffer",
"digest",
"opaque-debug",
]
[[package]]
name = "memchr"
version = "2.3.4"
@ -865,9 +828,9 @@ checksum = "d4fd5641d01c8f18a23da7b6fe29298ff4b55afcccdf78973b24cf3175fee32e"
[[package]]
name = "pin-project-lite"
version = "0.2.5"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0cf491442e4b033ed1c722cb9f0df5fcfcf4de682466c46469c36bc47dc5548a"
checksum = "dc0e1f259c92177c30a4c9d177246edd0a3568b25756a977d0632cf8fa37e905"
[[package]]
name = "pin-utils"
@ -1019,9 +982,9 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
[[package]]
name = "serde"
version = "1.0.123"
version = "1.0.124"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92d5161132722baa40d802cc70b15262b98258453e85e5d1d365c757c73869ae"
checksum = "bd761ff957cb2a45fbb9ab3da6512de9de55872866160b23c25f1a841e99d29f"
[[package]]
name = "sha-1"
@ -1159,41 +1122,20 @@ version = "0.6.0-pre"
dependencies = [
"anyhow",
"atoi",
"base64",
"bitflags",
"byteorder",
"bytes",
"bytestring",
"crypto-mac",
"either",
"conquer-once",
"futures-executor",
"futures-io",
"futures-util",
"hmac",
"itoa",
"log",
"md-5",
"memchr",
"percent-encoding",
"rand",
"rsa",
"sha-1",
"sha2",
"sqlx-core",
"stringprep",
"url",
]
[[package]]
name = "stringprep"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ee348cb74b87454fff4b551cbf727025810a004f88aeacae7f85b87f4e9a1c1"
dependencies = [
"unicode-bidi",
"unicode-normalization",
]
[[package]]
name = "subtle"
version = "2.4.0"
@ -1202,9 +1144,9 @@ checksum = "1e81da0851ada1f3e9d4312c704aa4f8806f0f9d69faaf8df2f3464b4a9437c2"
[[package]]
name = "syn"
version = "1.0.60"
version = "1.0.62"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c700597eca8a5a762beb35753ef6b94df201c81cca676604f547495a0d7f0081"
checksum = "123a78a3596b24fee53a6464ce52d8ecbf62241e6294c7e7fe12086cd161f512"
dependencies = [
"proc-macro2",
"quote",

View File

@ -9,5 +9,5 @@ pub trait RawValue<'r>: Sized {
fn is_null(&self) -> bool;
/// Returns the type information for this value.
fn type_info(&self) -> &'r <Self::Database as Database>::TypeInfo;
fn type_info(&self) -> &<Self::Database as Database>::TypeInfo;
}

View File

@ -23,7 +23,6 @@ default = []
blocking = ["sqlx-core/blocking"]
# async runtime
# not meant to be used directly
async = ["futures-util", "sqlx-core/async", "futures-io"]
[dependencies]

View File

@ -12,7 +12,7 @@ use crate::protocol::{Capabilities, Info, Status};
/// An OK packet is sent from the server to the client to signal successful completion of a command.
/// As of MySQL 5.7.5, OK packes are also used to indicate EOF, and EOF packets are deprecated.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
#[derive(Debug, Clone)]
pub(crate) struct OkPacket {
pub(crate) affected_rows: u64,
pub(crate) last_insert_id: u64,

View File

@ -4,11 +4,12 @@ use sqlx_core::QueryResult;
use crate::protocol::{Info, OkPacket, Status};
/// Represents the execution result of an operation on the database server.
/// Represents the execution result of an operation in MySQL.
///
/// Returned from [`execute()`][sqlx_core::Executor::execute].
///
#[allow(clippy::module_name_repetitions)]
#[derive(Clone)]
pub struct MySqlQueryResult(pub(crate) OkPacket);
impl MySqlQueryResult {
@ -109,18 +110,6 @@ impl Debug for MySqlQueryResult {
}
}
impl Default for MySqlQueryResult {
fn default() -> Self {
Self(OkPacket {
affected_rows: 0,
last_insert_id: 0,
status: Status::empty(),
warnings: 0,
info: Info::default(),
})
}
}
impl From<OkPacket> for MySqlQueryResult {
fn from(ok: OkPacket) -> Self {
Self(ok)

View File

@ -87,7 +87,7 @@ impl<'r> RawValue<'r> for MySqlRawValue<'r> {
self.value.is_none()
}
fn type_info(&self) -> &'r MySqlTypeInfo {
fn type_info(&self) -> &MySqlTypeInfo {
self.type_info
}
}

View File

@ -101,7 +101,7 @@ impl MySqlTypeInfo {
MySqlTypeId::CHAR => "CHAR",
MySqlTypeId::TEXT => "TEXT",
_ => "",
_ => "UNKNOWN",
}
}
}

View File

@ -2,7 +2,7 @@
name = "sqlx-postgres"
version = "0.6.0-pre"
repository = "https://github.com/launchbadge/sqlx"
description = "MySQL database driver for SQLx, the Rust SQL Toolkit."
description = "PostgreSQL database driver for SQLx, the Rust SQL Toolkit."
license = "MIT OR Apache-2.0"
edition = "2018"
keywords = ["postgres", "sqlx", "database"]
@ -23,13 +23,12 @@ default = []
blocking = ["sqlx-core/blocking"]
# async runtime
# not meant to be used directly
async = ["futures-util", "sqlx-core/async", "futures-io"]
[dependencies]
atoi = "0.4.0"
sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core" }
futures-util = { version = "0.3.8", optional = true }
either = "1.6.1"
log = "0.4.11"
bytestring = "1.0.0"
url = "2.2.0"
@ -38,20 +37,9 @@ futures-io = { version = "0.3", optional = true }
bytes = "1.0"
memchr = "2.3"
bitflags = "1.2"
sha-1 = "0.9.2"
sha2 = "0.9.2"
rsa = "0.3.0"
base64 = "0.13.0"
rand = "0.7"
itoa = "0.4.7"
atoi = "0.4.0"
byteorder = { version = "1.4.2", default-features = false, features = [ "std" ] }
md-5 = { version = "0.9.1", default-features = false }
hmac = { version = "0.10.1", default-features = false }
stringprep = "0.1.2"
crypto-mac = "0.10.0"
[dev-dependencies]
sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core", features = ["_mock"] }
futures-executor = "0.3.8"
anyhow = "1.0.37"
conquer-once = "0.3.2"

View File

@ -0,0 +1,31 @@
use bytestring::ByteString;
use sqlx_core::{Column, Database};
use crate::{PgTypeInfo, Postgres};
// TODO: inherent methods from <Column>
/// Represents a column from a query in Postgres.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone)]
pub struct PgColumn {
index: usize,
name: ByteString,
type_info: PgTypeInfo,
}
impl Column for PgColumn {
type Database = Postgres;
fn name(&self) -> &str {
&self.name
}
fn index(&self) -> usize {
self.index
}
fn type_info(&self) -> &PgTypeInfo {
&self.type_info
}
}

View File

@ -1,124 +0,0 @@
use std::fmt::{self, Debug, Formatter};
use sqlx_core::io::BufStream;
use sqlx_core::net::Stream as NetStream;
use sqlx_core::{Close, Connect, Connection, Runtime};
use crate::{Postgres, PostgresConnectOptions};
#[macro_use]
mod sasl;
mod close;
mod connect;
mod ping;
mod stream;
/// A single connection (also known as a session) to a PostgreSQL database server.
#[allow(clippy::module_name_repetitions)]
pub struct PostgresConnection<Rt>
where
Rt: Runtime,
{
stream: BufStream<Rt, NetStream<Rt>>,
// process id of this backend
// used to send cancel requests
#[allow(dead_code)]
process_id: u32,
// secret key of this backend
// used to send cancel requests
#[allow(dead_code)]
secret_key: u32,
}
impl<Rt> PostgresConnection<Rt>
where
Rt: Runtime,
{
pub(crate) fn new(stream: NetStream<Rt>) -> Self {
Self { stream: BufStream::with_capacity(stream, 4096, 1024), process_id: 0, secret_key: 0 }
}
}
impl<Rt> Debug for PostgresConnection<Rt>
where
Rt: Runtime,
{
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("PostgresConnection").finish()
}
}
impl<Rt> Connection<Rt> for PostgresConnection<Rt>
where
Rt: Runtime,
{
type Database = Postgres;
#[cfg(feature = "async")]
fn ping(&mut self) -> futures_util::future::BoxFuture<'_, sqlx_core::Result<()>>
where
Rt: sqlx_core::Async,
{
Box::pin(self.ping_async())
}
}
impl<Rt: Runtime> Connect<Rt> for PostgresConnection<Rt> {
type Options = PostgresConnectOptions<Rt>;
#[cfg(feature = "async")]
fn connect(url: &str) -> futures_util::future::BoxFuture<'_, sqlx_core::Result<Self>>
where
Self: Sized,
Rt: sqlx_core::Async,
{
use sqlx_core::ConnectOptions;
let options = url.parse::<Self::Options>();
Box::pin(async move { options?.connect().await })
}
}
impl<Rt: Runtime> Close<Rt> for PostgresConnection<Rt> {
#[cfg(feature = "async")]
fn close(self) -> futures_util::future::BoxFuture<'static, sqlx_core::Result<()>>
where
Rt: sqlx_core::Async,
{
Box::pin(self.close_async())
}
}
#[cfg(feature = "blocking")]
mod blocking {
use sqlx_core::blocking::{Close, Connect, Connection, Runtime};
use super::{PostgresConnectOptions, PostgresConnection};
impl<Rt: Runtime> Connection<Rt> for PostgresConnection<Rt> {
#[inline]
fn ping(&mut self) -> sqlx_core::Result<()> {
self.ping()
}
}
impl<Rt: Runtime> Connect<Rt> for PostgresConnection<Rt> {
#[inline]
fn connect(url: &str) -> sqlx_core::Result<Self>
where
Self: Sized,
{
Self::connect(&url.parse::<PostgresConnectOptions<Rt>>()?)
}
}
impl<Rt: Runtime> Close<Rt> for PostgresConnection<Rt> {
#[inline]
fn close(self) -> sqlx_core::Result<()> {
self.close()
}
}
}

View File

@ -1,32 +0,0 @@
use sqlx_core::{io::Stream, Result, Runtime};
use crate::protocol::Terminate;
impl<Rt> super::PostgresConnection<Rt>
where
Rt: Runtime,
{
#[cfg(feature = "async")]
pub(crate) async fn close_async(mut self) -> Result<()>
where
Rt: sqlx_core::Async,
{
self.write_packet(&Terminate)?;
self.stream.flush_async().await?;
self.stream.shutdown_async().await?;
Ok(())
}
#[cfg(feature = "blocking")]
pub(crate) fn close(mut self) -> Result<()>
where
Rt: sqlx_core::blocking::Runtime,
{
self.write_packet(&Terminate)?;
self.stream.flush()?;
self.stream.shutdown()?;
Ok(())
}
}

View File

@ -1,180 +0,0 @@
//! Implements the connection phase.
//!
//! The connection phase (establish) performs these tasks:
//!
//! - exchange the capabilities of client and server
//! - setup SSL communication channel if requested
//! - authenticate the client against the server
//!
//! The server may immediately send an ERR packet and finish the handshake
//! or send a `Handshake`.
//!
//! https://dev.postgres.com/doc/internals/en/connection-phase.html
//!
use hmac::{Hmac, Mac, NewMac};
use sha2::{Digest, Sha256};
use sqlx_core::net::Stream as NetStream;
use sqlx_core::Error;
use sqlx_core::Result;
use crate::protocol::{
Authentication, BackendKeyData, Message, MessageType, Password, ReadyForQuery,
SaslInitialResponse, SaslResponse, Startup,
};
use crate::{PostgresConnectOptions, PostgresConnection, PostgresDatabaseError};
macro_rules! connect {
(@blocking @tcp $options:ident) => {
NetStream::connect($options.address.as_ref())?;
};
(@tcp $options:ident) => {
NetStream::connect_async($options.address.as_ref()).await?;
};
(@blocking @packet $self:ident) => {
$self.read_packet()?;
};
(@packet $self:ident) => {
$self.read_packet_async().await?;
};
($(@$blocking:ident)? $options:ident) => {{
// open a network stream to the database server
let stream = connect!($(@$blocking)? @tcp $options);
// construct a <PostgresConnection> around the network stream
// wraps the stream in a <BufStream> to buffer read and write
let mut self_ = Self::new(stream);
// To begin a session, a frontend opens a connection to the server
// and sends a startup message.
let mut params = vec![ // Sets the display format for date and time values, as well as the rules for interpreting ambiguous date input values. ("DateStyle", "ISO, MDY"),
// Sets the client-side encoding (character set).
// <https://www.postgresql.org/docs/devel/multibyte.html#MULTIBYTE-CHARSET-SUPPORTED>
("client_encoding", "UTF8"),
// Sets the time zone for displaying and interpreting time stamps.
("TimeZone", "UTC"),
// Adjust postgres to return precise values for floats
// NOTE: This is default in postgres 12+
("extra_float_digits", "3"),
];
// if let Some(ref application_name) = $options.get_application_name() {
// params.push(("application_name", application_name));
// }
self_.write_packet(&Startup {
username: $options.get_username(),
database: $options.get_database(),
params: &params,
})?;
// The server then uses this information and the contents of
// its configuration files (such as pg_hba.conf) to determine whether the connection is
// provisionally acceptable, and what additional
// authentication is required (if any).
let mut process_id = 0;
let mut secret_key = 0;
let transaction_status;
loop {
let message: Message = connect!($(@$blocking)? @packet self_);
match message.r#type {
MessageType::Authentication => match message.decode()? {
Authentication::Ok => {
// the authentication exchange is successfully completed
// do nothing; no more information is required to continue
}
Authentication::CleartextPassword => {
// The frontend must now send a [PasswordMessage] containing the
// password in clear-text form.
self_
.write_packet(&Password::Cleartext(
$options.get_password().unwrap_or_default(),
))?;
}
Authentication::Md5Password(body) => {
// The frontend must now send a [PasswordMessage] containing the
// password (with user name) encrypted via MD5, then encrypted again
// using the 4-byte random salt specified in the
// [AuthenticationMD5Password] message.
self_
.write_packet(&Password::Md5 {
username: $options.get_username().unwrap_or_default(),
password: $options.get_password().unwrap_or_default(),
salt: body.salt,
})?;
}
Authentication::Sasl(body) => {
sasl_authenticate!($(@$blocking)? self_, $options, body)
}
method => {
return Err(Error::configuration_msg(format!(
"unsupported authentication method: {:?}",
method
)));
}
},
MessageType::BackendKeyData => {
// provides secret-key data that the frontend must save if it wants to be
// able to issue cancel requests later
let data: BackendKeyData = message.decode()?;
process_id = data.process_id;
secret_key = data.secret_key;
}
MessageType::ReadyForQuery => {
let ready: ReadyForQuery = message.decode()?;
// start-up is completed. The frontend can now issue commands
transaction_status = ready.transaction_status;
break;
}
_ => {
return Err(Error::configuration_msg(format!(
"establish: unexpected message: {:?}",
message.r#type
)))
}
}
}
Ok(self_)
}};
}
impl<Rt> PostgresConnection<Rt>
where
Rt: sqlx_core::Runtime,
{
#[cfg(feature = "async")]
pub(crate) async fn connect_async(options: &PostgresConnectOptions<Rt>) -> Result<Self>
where
Rt: sqlx_core::Async,
{
connect!(options)
}
#[cfg(feature = "blocking")]
pub(crate) fn connect(options: &PostgresConnectOptions<Rt>) -> Result<Self>
where
Rt: sqlx_core::blocking::Runtime,
{
connect!(@blocking options)
}
}

View File

@ -1,22 +0,0 @@
use sqlx_core::{Result, Runtime};
impl<Rt> super::PostgresConnection<Rt>
where
Rt: Runtime,
{
#[cfg(feature = "async")]
pub(crate) async fn ping_async(&mut self) -> Result<()>
where
Rt: sqlx_core::Async,
{
todo!();
}
#[cfg(feature = "blocking")]
pub(crate) fn ping(&mut self) -> Result<()>
where
Rt: sqlx_core::blocking::Runtime,
{
todo!();
}
}

View File

@ -1,294 +0,0 @@
use hmac::{Hmac, Mac, NewMac};
use rand::Rng;
use sha2::digest::Digest;
use sha2::Sha256;
use sqlx_core::Error;
use sqlx_core::Result;
use sqlx_core::Runtime;
use crate::protocol::{
Authentication, AuthenticationSasl, Message, MessageType, SaslInitialResponse, SaslResponse,
};
use crate::PostgresConnectOptions;
use crate::PostgresConnection;
use crate::PostgresDatabaseError;
pub(super) const GS2_HEADER: &str = "n,,";
pub(super) const CHANNEL_ATTR: &str = "c";
pub(super) const USERNAME_ATTR: &str = "n";
pub(super) const CLIENT_PROOF_ATTR: &str = "p";
pub(super) const NONCE_ATTR: &str = "r";
pub(super) fn sasl_init_response<Rt: Runtime>(
self_: &mut PostgresConnection<Rt>,
options: &PostgresConnectOptions<Rt>,
data: AuthenticationSasl,
) -> Result<(String, String, String)> {
let mut has_sasl = false;
let mut has_sasl_plus = false;
let mut unknown = Vec::new();
for mechanism in data.mechanisms() {
match mechanism {
"SCRAM-SHA-256" => {
has_sasl = true;
}
"SCRAM-SHA-256-PLUS" => {
has_sasl_plus = true;
}
_ => {
unknown.push(mechanism.to_owned());
}
}
}
if !has_sasl_plus && !has_sasl {
return Err(Error::connect(PostgresDatabaseError::protocol(format!(
"unsupported SASL authentication mechanisms: {}",
unknown.join(", ")
))));
}
// channel-binding = "c=" base64
let channel_binding = format!(
"{}={}",
crate::connection::sasl::CHANNEL_ATTR,
base64::encode(crate::connection::sasl::GS2_HEADER)
);
// "n=" saslname ;; Usernames are prepared using SASLprep.
let username = format!("{}={}", USERNAME_ATTR, options.get_username().unwrap_or_default());
let username = match stringprep::saslprep(&username) {
Ok(v) => v,
Err(err) => {
return Err(Error::connect(PostgresDatabaseError::protocol(format!(
"failed to sasl prep the username: {:?}",
err
))));
}
};
// nonce = "r=" c-nonce [s-nonce] ;; Second part provided by server.
let nonce = gen_nonce();
// client-first-message-bare = [reserved-mext ","] username "," nonce ["," extensions]
let client_first_message_bare =
format!("{username},{nonce}", username = username, nonce = nonce);
let client_first_message = format!(
"{gs2_header}{client_first_message_bare}",
gs2_header = GS2_HEADER,
client_first_message_bare = client_first_message_bare
);
self_.write_packet(&SaslInitialResponse {
response: client_first_message.clone(),
plus: false,
})?;
Ok((channel_binding, client_first_message_bare, client_first_message))
}
pub(super) fn sasl_response<'a, Rt: Runtime>(
self_: &mut PostgresConnection<Rt>,
message: Message,
options: &PostgresConnectOptions<Rt>,
channel_binding: &String,
client_first_message_bare: &'a String,
) -> Result<Hmac<Sha256>> {
let cont = match message.r#type {
MessageType::Authentication => match message.decode()? {
Authentication::SaslContinue(data) => data,
auth => {
return Err(Error::connect(PostgresDatabaseError::protocol(format!(
"expected SASLContinue but received {:?}",
auth
))));
}
},
r#type => {
return Err(Error::connect(PostgresDatabaseError::protocol(format!(
"Expected an authencation message type, found {:?}",
r#type
))));
}
};
// SaltedPassword := Hi(Normalize(password), salt, i)
let salted_password =
hi(options.get_password().unwrap_or_default(), &cont.salt, cont.iterations)?;
// ClientKey := HMAC(SaltedPassword, "Client Key")
let mut mac = Hmac::<Sha256>::new_varkey(&salted_password)
.map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?;
mac.update(b"Client Key");
let client_key = mac.finalize().into_bytes();
// StoredKey := H(ClientKey)
let stored_key = Sha256::digest(&client_key);
// client-final-message-without-proof
let client_final_message_wo_proof = format!(
"{channel_binding},r={nonce}",
channel_binding = channel_binding,
nonce = &cont.nonce
);
// AuthMessage := client-first-message-bare + "," + server-first-message + "," + client-final-message-without-proof
let auth_message = format!(
"{client_first_message_bare},{server_first_message},{client_final_message_wo_proof}",
client_first_message_bare = client_first_message_bare,
server_first_message = cont.message,
client_final_message_wo_proof = client_final_message_wo_proof
);
// ClientSignature := HMAC(StoredKey, AuthMessage)
let mut mac = Hmac::<Sha256>::new_varkey(&stored_key)
.map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?;
mac.update(&auth_message.as_bytes());
let client_signature = mac.finalize().into_bytes();
// ClientProof := ClientKey XOR ClientSignature
let client_proof: Vec<u8> =
client_key.iter().zip(client_signature.iter()).map(|(&a, &b)| a ^ b).collect();
// ServerKey := HMAC(SaltedPassword, "Server Key")
let mut mac = Hmac::<Sha256>::new_varkey(&salted_password)
.map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?;
mac.update(b"Server Key");
let server_key = mac.finalize().into_bytes();
// ServerSignature := HMAC(ServerKey, AuthMessage)
let mut mac = Hmac::<Sha256>::new_varkey(&server_key)
.map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?;
mac.update(&auth_message.as_bytes());
// client-final-message = client-final-message-without-proof "," proof
let client_final_message = format!(
"{client_final_message_wo_proof},{client_proof_attr}={client_proof}",
client_final_message_wo_proof = client_final_message_wo_proof,
client_proof_attr = crate::connection::sasl::CLIENT_PROOF_ATTR,
client_proof = base64::encode(&client_proof)
);
self_.write_packet(&SaslResponse(client_first_message_bare))?;
Ok(mac)
}
pub(super) fn sasl_final(message: Message, mac: Hmac<Sha256>) -> Result<()> {
let data = match message.r#type {
MessageType::Authentication => match message.decode()? {
Authentication::SaslFinal(data) => data,
auth => {
return Err(Error::connect(PostgresDatabaseError::protocol(format!(
"expected SASLContinue but received {:?}",
auth
))));
}
},
r#type => {
return Err(Error::connect(PostgresDatabaseError::protocol(format!(
"Expected an authencation message type, found {:?}",
r#type
))));
}
};
// authentication is only considered valid if this verification passes
mac.verify(&data.verifier)
.map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?;
Ok(())
}
macro_rules! sasl_authenticate {
(@blocking @packet $self:ident) => {
$self.read_packet()?;
};
(@packet $self:ident) => {
$self.read_packet_async().await?;
};
($(@$blocking:ident)? $self:ident, $options:ident, $data:ident) => {{
let (
channel_binding,
client_first_message_bare,
client_first_message,
) = crate::connection::sasl::sasl_init_response(&mut $self, $options, $data)?;
let message: Message = sasl_authenticate!($(@$blocking)? @packet $self);
let mac = crate::connection::sasl::sasl_response(
&mut $self,
message,
$options,
&channel_binding,
&client_first_message_bare,
)?;
let message: Message = sasl_authenticate!($(@$blocking)? @packet $self);
crate::connection::sasl::sasl_final(
message,
mac
)?;
}};
}
// nonce is a sequence of random printable bytes
fn gen_nonce() -> String {
let mut rng = rand::thread_rng();
let count = rng.gen_range(64, 128);
// printable = %x21-2B / %x2D-7E
// ;; Printable ASCII except ",".
// ;; Note that any "printable" is also
// ;; a valid "value".
let nonce: String = std::iter::repeat(())
.map(|()| {
let mut c = rng.gen_range(0x21, 0x7F) as u8;
while c == 0x2C {
c = rng.gen_range(0x21, 0x7F) as u8;
}
c
})
.take(count)
.map(|c| c as char)
.collect();
rng.gen_range(32, 128);
format!("{}={}", crate::connection::sasl::NONCE_ATTR, nonce)
}
// Hi(str, salt, i):
fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32]> {
let mut mac = Hmac::<Sha256>::new_varkey(s.as_bytes())
.map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?;
mac.update(&salt);
mac.update(&1u32.to_be_bytes());
let mut u = mac.finalize().into_bytes();
let mut hi = u;
for _ in 1..iter_count {
let mut mac = Hmac::<Sha256>::new_varkey(s.as_bytes())
.map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?;
mac.update(u.as_slice());
u = mac.finalize().into_bytes();
hi = hi.iter().zip(u.iter()).map(|(&a, &b)| a ^ b).collect();
}
Ok(hi.into())
}

View File

@ -1,162 +0,0 @@
//! Reads and writes packets to and from the PostgreSQL database server.
//!
//! The logic for serializing data structures into the packets is found
//! mostly in `protocol/`.
//!
//! Packets in PostgreSQL are prefixed by 4 bytes.
//! 3 for length (in LE) and a sequence id.
//!
//! Packets may only be as large as the communicated size in the initial
//! `HandshakeResponse`. By default, SQLx configures its chunk size to 16M. Sending
//! a larger payload is simply sending completely "full" packets, one after the
//! other, with an increasing sequence id.
//!
//! In other words, when we sent data, we:
//!
//! - Split the data into "packets" of size `2 ** 24 - 1` bytes.
//!
//! - Prepend each packet with a **packet header**, consisting of the length of that packet,
//! and the sequence number.
//!
//! https://dev.postgres.com/doc/internals/en/postgres-packet.html
//!
use std::convert::TryFrom;
use std::fmt::Debug;
use bytes::{Buf, Bytes};
use log::Level;
use sqlx_core::io::{Deserialize, Serialize};
use sqlx_core::{Result, Runtime};
use crate::protocol::{Message, MessageType, Notice, PgSeverity};
use crate::PostgresConnection;
impl<Rt> PostgresConnection<Rt>
where
Rt: Runtime,
{
pub(super) fn write_packet<'ser, T>(&'ser mut self, packet: &T) -> Result<()>
where
T: Serialize<'ser, ()> + Debug,
{
log::trace!("write > {:?}", packet);
let buf = self.stream.buffer();
packet.serialize_with(buf, ())?;
Ok(())
}
}
macro_rules! read_packet {
($(@$blocking:ident)? $self:ident) => {{
loop {
read_packet!($(@$blocking)? @stream $self, 0, 5);
// peek at the messaage type and payload size
let r#type = MessageType::try_from(*$self.stream.get(0, 1))?;
let size = (u32::from_be_bytes($self.stream.get(1, 4)) - 4) as usize;
read_packet!($(@$blocking)? @stream $self, 5, size);
// take the whole packet
$self.stream.consume(5);
let contents = $self.stream.take(size);
let message = Message { r#type, contents };
match message.r#type {
MessageType::ErrorResponse => {
// An error returned from the database server.
// return Err(PgDatabaseError(message.decode()?).into());
panic!("got error response");
}
MessageType::NotificationResponse => {
// if let Some(buffer) = &mut self.notifications {
// let notification: Notification = message.decode()?;
// let _ = self.write_packet(notification);
// continue;
// }
continue;
}
MessageType::ParameterStatus => {
// informs the frontend about the current (initial)
// setting of backend parameters
// we currently have no use for that data so we promptly ignore this message
continue;
}
MessageType::NoticeResponse => {
// do we need this to be more configurable?
// if you are reading this comment and think so, open an issue
let notice: Notice = message.decode()?;
let lvl = match notice.severity() {
PgSeverity::Fatal | PgSeverity::Panic | PgSeverity::Error => Level::Error,
PgSeverity::Warning => Level::Warn,
PgSeverity::Notice => Level::Info,
PgSeverity::Debug => Level::Debug,
PgSeverity::Info => Level::Trace,
PgSeverity::Log => Level::Trace,
};
if lvl <= log::STATIC_MAX_LEVEL && lvl <= log::max_level() {
log::logger().log(
&log::Record::builder()
.args(format_args!("{}", notice.message()))
.level(lvl)
.module_path_static(Some("sqlx::postgres::notice"))
.file_static(Some(file!()))
.line(Some(line!()))
.build(),
);
}
continue;
}
_ => {}
}
return T::deserialize_with(message.contents, ());
}
}};
(@blocking @stream $self:ident, $offset:expr, $n:expr) => {
$self.stream.read($offset, $n)?;
};
(@stream $self:ident, $offset:expr, $n:expr) => {
$self.stream.read_async($offset, $n).await?;
};
}
impl<Rt> PostgresConnection<Rt>
where
Rt: Runtime,
{
#[cfg(feature = "async")]
pub(super) async fn read_packet_async<'de, T>(&'de mut self) -> Result<T>
where
T: Deserialize<'de, ()> + Debug,
Rt: sqlx_core::Async,
{
read_packet!(self)
}
#[cfg(feature = "blocking")]
pub(super) fn read_packet<'de, T>(&'de mut self) -> Result<T>
where
T: Deserialize<'de, ()> + Debug,
Rt: sqlx_core::blocking::Runtime,
{
read_packet!(@blocking self)
}
}

View File

@ -1,15 +1,30 @@
use sqlx_core::{Database, HasOutput, Runtime};
use sqlx_core::database::{HasOutput, HasRawValue};
use sqlx_core::Database;
use super::{PgColumn, PgOutput, PgQueryResult, PgRawValue, PgRow, PgTypeId, PgTypeInfo};
#[derive(Debug)]
pub struct Postgres;
impl<Rt> Database<Rt> for Postgres
where
Rt: Runtime,
{
type Connection = super::PostgresConnection<Rt>;
impl Database for Postgres {
type Column = PgColumn;
type Row = PgRow;
type QueryResult = PgQueryResult;
type TypeInfo = PgTypeInfo;
type TypeId = PgTypeId;
}
// 'x: execution
impl<'x> HasOutput<'x> for Postgres {
type Output = &'x mut Vec<u8>;
type Output = PgOutput<'x>;
}
// 'r: row
impl<'r> HasRawValue<'r> for Postgres {
type Database = Self;
type RawValue = PgRawValue<'r>;
}

View File

@ -1,41 +0,0 @@
use std::error::Error as StdError;
use std::fmt::{self, Display, Formatter};
use sqlx_core::DatabaseError;
/// An error returned from the PostgreSQL database server.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct PostgresDatabaseError(String);
impl PostgresDatabaseError {
pub(crate) fn protocol(msg: String) -> PostgresDatabaseError {
PostgresDatabaseError(msg)
}
}
impl From<crypto_mac::InvalidKeyLength> for PostgresDatabaseError {
fn from(err: crypto_mac::InvalidKeyLength) -> Self {
PostgresDatabaseError::protocol(err.to_string())
}
}
impl From<crypto_mac::MacError> for PostgresDatabaseError {
fn from(err: crypto_mac::MacError) -> Self {
PostgresDatabaseError::protocol(err.to_string())
}
}
impl DatabaseError for PostgresDatabaseError {
fn message(&self) -> &str {
&self.0
}
}
impl Display for PostgresDatabaseError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl StdError for PostgresDatabaseError {}

View File

@ -1,3 +0,0 @@
mod write;
pub(crate) use write::PgBufMutExt;

View File

@ -1,52 +0,0 @@
pub trait PgBufMutExt {
fn write_length_prefixed<F>(&mut self, f: F)
where
F: FnOnce(&mut Vec<u8>);
fn write_statement_name(&mut self, id: u32);
fn write_portal_name(&mut self, id: Option<u32>);
}
impl PgBufMutExt for Vec<u8> {
// writes a length-prefixed message, this is used when encoding nearly all messages as postgres
// wants us to send the length of the often-variable-sized messages up front
fn write_length_prefixed<F>(&mut self, f: F)
where
F: FnOnce(&mut Vec<u8>),
{
// reserve space to write the prefixed length
let offset = self.len();
self.extend(&[0; 4]);
// write the main body of the message
f(self);
// now calculate the size of what we wrote and set the length value
let size = (self.len() - offset) as i32;
self[offset..(offset + 4)].copy_from_slice(&size.to_be_bytes());
}
// writes a statement name by ID
#[inline]
fn write_statement_name(&mut self, id: u32) {
// N.B. if you change this don't forget to update it in ../describe.rs
self.extend(b"sqlx_s_");
itoa::write(&mut *self, id).unwrap();
self.push(0);
}
// writes a portal name by ID
#[inline]
fn write_portal_name(&mut self, id: Option<u32>) {
if let Some(id) = id {
self.extend(b"sqlx_p_");
itoa::write(&mut *self, id).unwrap();
}
self.push(0);
}
}

View File

@ -1,6 +1,6 @@
//! [PostgreSQL] database driver.
//!
//! [PostgreSQL]: https://www.postgres.com/
//! [PostgreSQL]: https://www.postgresql.org/
//!
#![cfg_attr(doc_cfg, feature(doc_cfg))]
#![cfg_attr(not(any(feature = "async", feature = "blocking")), allow(unused))]
@ -17,18 +17,45 @@
#![warn(clippy::use_self)]
#![warn(clippy::useless_let_if_seq)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::missing_panics_doc)]
mod connection;
use sqlx_core::Arguments;
#[macro_use]
mod stream;
mod column;
// mod connection;
mod database;
mod error;
mod io;
mod options;
// mod error;
// mod io;
// mod options;
mod output;
mod protocol;
mod query_result;
// mod raw_statement;
mod raw_value;
mod row;
// mod transaction;
mod type_id;
mod type_info;
pub mod types;
#[cfg(test)]
mod mock;
// #[cfg(test)]
// mod mock;
pub use connection::PostgresConnection;
pub use column::PgColumn;
// pub use connection::PgConnection;
pub use database::Postgres;
pub use error::PostgresDatabaseError;
pub use options::PostgresConnectOptions;
// pub use error::PgDatabaseError;
// pub use options::PgConnectOptions;
pub use output::PgOutput;
pub use query_result::PgQueryResult;
pub use raw_value::{PgRawValue, PgRawValueFormat};
pub use row::PgRow;
pub use type_id::PgTypeId;
pub use type_info::PgTypeInfo;
// 'a: argument values
pub type PgArguments<'a> = Arguments<'a, Postgres>;

View File

@ -1,103 +0,0 @@
use std::fmt::{self, Debug, Formatter};
use std::marker::PhantomData;
use std::path::PathBuf;
use either::Either;
use sqlx_core::{ConnectOptions, Runtime};
use crate::PostgresConnection;
mod builder;
mod default;
mod getters;
mod parse;
// TODO: RSA Public Key (to avoid the key exchange for caching_sha2 and sha256 plugins)
/// Options which can be used to configure how a PostgreSQL connection is opened.
///
#[allow(clippy::module_name_repetitions)]
pub struct PostgresConnectOptions<Rt>
where
Rt: Runtime,
{
runtime: PhantomData<Rt>,
pub(crate) address: Either<(String, u16), PathBuf>,
username: Option<String>,
password: Option<String>,
database: Option<String>,
timezone: String,
charset: String,
}
impl<Rt> Clone for PostgresConnectOptions<Rt>
where
Rt: Runtime,
{
fn clone(&self) -> Self {
Self {
runtime: PhantomData,
address: self.address.clone(),
username: self.username.clone(),
password: self.password.clone(),
database: self.database.clone(),
timezone: self.timezone.clone(),
charset: self.charset.clone(),
}
}
}
impl<Rt> Debug for PostgresConnectOptions<Rt>
where
Rt: Runtime,
{
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("PostgresConnectOptions")
.field(
"address",
&self
.address
.as_ref()
.map_left(|(host, port)| format!("{}:{}", host, port))
.map_right(|socket| socket.display()),
)
.field("username", &self.username)
.field("password", &self.password)
.field("database", &self.database)
.field("timezone", &self.timezone)
.field("charset", &self.charset)
.finish()
}
}
impl<Rt> ConnectOptions<Rt> for PostgresConnectOptions<Rt>
where
Rt: Runtime,
{
type Connection = PostgresConnection<Rt>;
#[cfg(feature = "async")]
fn connect(&self) -> futures_util::future::BoxFuture<'_, sqlx_core::Result<Self::Connection>>
where
Self::Connection: Sized,
Rt: sqlx_core::Async,
{
Box::pin(PostgresConnection::<Rt>::connect_async(self))
}
}
#[cfg(feature = "blocking")]
mod blocking {
use sqlx_core::blocking::{ConnectOptions, Runtime};
use super::{PostgresConnectOptions, PostgresConnection};
impl<Rt: Runtime> ConnectOptions<Rt> for PostgresConnectOptions<Rt> {
fn connect(&self) -> sqlx_core::Result<Self::Connection>
where
Self::Connection: Sized,
{
<PostgresConnection<Rt>>::connect(self)
}
}
}

View File

@ -1,82 +0,0 @@
use std::mem;
use std::path::{Path, PathBuf};
use either::Either;
use sqlx_core::Runtime;
impl<Rt> super::PostgresConnectOptions<Rt>
where
Rt: Runtime,
{
/// Sets the hostname of the database server.
///
/// If the hostname begins with a slash (`/`), it is interpreted as the absolute path
/// to a Unix domain socket file instead of a hostname of a server.
///
/// Defaults to `localhost`.
///
pub fn host(&mut self, host: impl AsRef<str>) -> &mut Self {
let host = host.as_ref();
self.address = if host.starts_with('/') {
Either::Right(PathBuf::from(&*host))
} else {
Either::Left((host.into(), self.get_port()))
};
self
}
/// Sets the path of the Unix domain socket to connect to.
///
/// Overrides [`host()`](#method.host) and [`port()`](#method.port).
///
pub fn socket(&mut self, socket: impl AsRef<Path>) -> &mut Self {
self.address = Either::Right(socket.as_ref().to_owned());
self
}
/// Sets the TCP port number of the database server.
///
/// Defaults to `3306`.
///
pub fn port(&mut self, port: u16) -> &mut Self {
self.address = match self.address {
Either::Right(_) => Either::Left(("localhost".to_owned(), port)),
Either::Left((ref mut host, _)) => Either::Left((mem::take(host), port)),
};
self
}
/// Sets the username to be used for authentication.
// FIXME: Specify what happens when you do NOT set this
pub fn username(&mut self, username: impl AsRef<str>) -> &mut Self {
self.username = Some(username.as_ref().to_owned());
self
}
/// Sets the password to be used for authentication.
pub fn password(&mut self, password: impl AsRef<str>) -> &mut Self {
self.password = Some(password.as_ref().to_owned());
self
}
/// Sets the default database for the connection.
pub fn database(&mut self, database: impl AsRef<str>) -> &mut Self {
self.database = Some(database.as_ref().to_owned());
self
}
/// Sets the character set for the connection.
pub fn charset(&mut self, charset: impl AsRef<str>) -> &mut Self {
self.charset = charset.as_ref().to_owned();
self
}
/// Sets the timezone for the connection.
pub fn timezone(&mut self, timezone: impl AsRef<str>) -> &mut Self {
self.timezone = timezone.as_ref().to_owned();
self
}
}

View File

@ -1,38 +0,0 @@
use std::marker::PhantomData;
use either::Either;
use sqlx_core::Runtime;
use crate::PostgresConnectOptions;
pub(crate) const HOST: &str = "localhost";
pub(crate) const PORT: u16 = 3306;
impl<Rt> Default for PostgresConnectOptions<Rt>
where
Rt: Runtime,
{
fn default() -> Self {
Self {
runtime: PhantomData,
address: Either::Left((HOST.to_owned(), PORT)),
username: None,
password: None,
database: None,
charset: "utf8mb4".to_owned(),
timezone: "utc".to_owned(),
// todo: connect_timeout
}
}
}
impl<Rt> super::PostgresConnectOptions<Rt>
where
Rt: Runtime,
{
/// Creates a default set of options ready for configuration.
#[must_use]
pub fn new() -> Self {
Self::default()
}
}

View File

@ -1,55 +0,0 @@
use std::path::{Path, PathBuf};
use sqlx_core::Runtime;
use super::{default, PostgresConnectOptions};
impl<Rt: Runtime> PostgresConnectOptions<Rt> {
/// Returns the hostname of the database server.
#[must_use]
pub fn get_host(&self) -> &str {
self.address.as_ref().left().map_or(default::HOST, |(host, _)| &**host)
}
/// Returns the TCP port number of the database server.
#[must_use]
pub fn get_port(&self) -> u16 {
self.address.as_ref().left().map_or(default::PORT, |(_, port)| *port)
}
/// Returns the path to the Unix domain socket, if one is configured.
#[must_use]
pub fn get_socket(&self) -> Option<&Path> {
self.address.as_ref().right().map(PathBuf::as_path)
}
/// Returns the default database name.
#[must_use]
pub fn get_database(&self) -> Option<&str> {
self.database.as_deref()
}
/// Returns the username to be used for authentication.
#[must_use]
pub fn get_username(&self) -> Option<&str> {
self.username.as_deref()
}
/// Returns the password to be used for authentication.
#[must_use]
pub fn get_password(&self) -> Option<&str> {
self.password.as_deref()
}
/// Returns the character set for the connection.
#[must_use]
pub fn get_charset(&self) -> &str {
&self.charset
}
/// Returns the timezone for the connection.
#[must_use]
pub fn get_timezone(&self) -> &str {
&self.timezone
}
}

View File

@ -1,183 +0,0 @@
use std::borrow::Cow;
use std::str::FromStr;
use percent_encoding::percent_decode_str;
use sqlx_core::{Error, Runtime};
use url::Url;
use crate::PostgresConnectOptions;
impl<Rt> FromStr for PostgresConnectOptions<Rt>
where
Rt: Runtime,
{
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let url: Url =
s.parse().map_err(|error| Error::configuration("for database URL", error))?;
if !matches!(url.scheme(), "postgres") {
return Err(Error::configuration_msg(format!(
"unsupported URL scheme {:?} for MySQL",
url.scheme()
)));
}
let mut options = Self::new();
if let Some(host) = url.host_str() {
options.host(percent_decode_str_utf8(host));
}
if let Some(port) = url.port() {
options.port(port);
}
let username = url.username();
if !username.is_empty() {
options.username(percent_decode_str_utf8(username));
}
if let Some(password) = url.password() {
options.password(percent_decode_str_utf8(password));
}
let mut path = url.path();
if path.starts_with('/') {
path = &path[1..];
}
if !path.is_empty() {
options.database(path);
}
for (key, value) in url.query_pairs() {
match &*key {
"user" | "username" => {
options.username(value);
}
"password" => {
options.password(value);
}
// ssl-mode compatibly with SQLx <= 0.5
// sslmode compatibly with PostgreSQL
// sslMode compatibly with JDBC MySQL
// tls compatibly with Go MySQL [preferred]
"ssl-mode" | "sslmode" | "sslMode" | "tls" => {
todo!()
}
"charset" => {
options.charset(value);
}
"timezone" => {
options.timezone(value);
}
"socket" => {
options.socket(&*value);
}
_ => {
// ignore unknown connection parameters
// fixme: should we error or warn here?
}
}
}
Ok(options)
}
}
// todo: this should probably go somewhere common
fn percent_decode_str_utf8(value: &str) -> Cow<'_, str> {
percent_decode_str(value).decode_utf8_lossy()
}
#[cfg(test)]
mod tests {
use std::path::Path;
use sqlx_core::mock::Mock;
use super::PostgresConnectOptions;
#[test]
fn parse() {
let url = "postgres://user:password@hostname:5432/database?timezone=system&charset=utf8";
let options: PostgresConnectOptions<Mock> = url.parse().unwrap();
assert_eq!(options.get_username(), Some("user"));
assert_eq!(options.get_password(), Some("password"));
assert_eq!(options.get_host(), "hostname");
assert_eq!(options.get_port(), 5432);
assert_eq!(options.get_database(), Some("database"));
assert_eq!(options.get_timezone(), "system");
assert_eq!(options.get_charset(), "utf8");
}
#[test]
fn parse_with_defaults() {
let url = "postgres://";
let options: PostgresConnectOptions<Mock> = url.parse().unwrap();
assert_eq!(options.get_username(), None);
assert_eq!(options.get_password(), None);
assert_eq!(options.get_host(), "localhost");
assert_eq!(options.get_port(), 3306);
assert_eq!(options.get_database(), None);
assert_eq!(options.get_timezone(), "utc");
assert_eq!(options.get_charset(), "utf8mb4");
}
#[test]
fn parse_socket_from_query() {
let url = "postgres://user:password@localhost/database?socket=/var/run/postgresd/postgresd.sock";
let options: PostgresConnectOptions<Mock> = url.parse().unwrap();
assert_eq!(options.get_username(), Some("user"));
assert_eq!(options.get_password(), Some("password"));
assert_eq!(options.get_database(), Some("database"));
assert_eq!(options.get_socket(), Some(Path::new("/var/run/postgresd/postgresd.sock")));
}
#[test]
fn parse_socket_from_host() {
// socket path in host requires URL encoding - but does work
let url = "postgres://user:password@%2Fvar%2Frun%2Fpostgresd%2Fpostgresd.sock/database";
let options: PostgresConnectOptions<Mock> = url.parse().unwrap();
assert_eq!(options.get_username(), Some("user"));
assert_eq!(options.get_password(), Some("password"));
assert_eq!(options.get_database(), Some("database"));
assert_eq!(options.get_socket(), Some(Path::new("/var/run/postgresd/postgresd.sock")));
}
#[test]
#[should_panic]
fn fail_to_parse_non_postgres() {
let url = "postgres://user:password@hostname:5432/database?timezone=system&charset=utf8";
let _: PostgresConnectOptions<Mock> = url.parse().unwrap();
}
#[test]
fn parse_username_with_at_sign() {
let url = "postgres://user@hostname:password@hostname:5432/database";
let options: PostgresConnectOptions<Mock> = url.parse().unwrap();
assert_eq!(options.get_username(), Some("user@hostname"));
}
#[test]
fn parse_password_with_non_ascii_chars() {
let url = "postgres://username:p@ssw0rd@hostname:5432/database";
let options: PostgresConnectOptions<Mock> = url.parse().unwrap();
assert_eq!(options.get_password(), Some("p@ssw0rd"));
}
}

View File

@ -0,0 +1,15 @@
// 'x: execution
#[allow(clippy::module_name_repetitions)]
pub struct PgOutput<'x> {
buffer: &'x mut Vec<u8>,
}
impl<'x> PgOutput<'x> {
pub(crate) fn new(buffer: &'x mut Vec<u8>) -> Self {
Self { buffer }
}
pub(crate) fn buffer(&mut self) -> &mut Vec<u8> {
self.buffer
}
}

View File

@ -1,111 +1,3 @@
use std::convert::TryFrom;
mod message;
use bytes::{Buf, Bytes};
use sqlx_core::io::Deserialize;
use sqlx_core::Error;
use sqlx_core::Result;
mod authentication;
mod backend_key_data;
mod close;
mod notification;
mod password;
mod ready_for_query;
mod response;
mod sasl;
mod startup;
mod terminate;
pub(crate) use authentication::{Authentication, AuthenticationSasl};
pub(crate) use backend_key_data::BackendKeyData;
pub(crate) use close::Close;
pub(crate) use notification::Notification;
pub(crate) use password::Password;
pub(crate) use ready_for_query::ReadyForQuery;
pub(crate) use response::{Notice, PgSeverity};
pub(crate) use sasl::{SaslInitialResponse, SaslResponse};
pub(crate) use startup::Startup;
pub(crate) use terminate::Terminate;
#[derive(Debug, Clone, Copy, PartialOrd, PartialEq, Eq)]
#[repr(u8)]
pub enum MessageType {
ParseComplete = b'1',
BindComplete = b'2',
CloseComplete = b'3',
CommandComplete = b'C',
DataRow = b'D',
ErrorResponse = b'E',
EmptyQueryResponse = b'I',
NotificationResponse = b'A',
BackendKeyData = b'K',
NoticeResponse = b'N',
Authentication = b'R',
ParameterStatus = b'S',
RowDescription = b'T',
ReadyForQuery = b'Z',
NoData = b'n',
PortalSuspended = b's',
ParameterDescription = b't',
}
#[derive(Debug)]
pub struct Message {
pub r#type: MessageType,
pub contents: Bytes,
}
impl Message {
#[inline]
pub fn decode<'de, T>(self) -> Result<T>
where
T: Deserialize<'de, ()>,
{
T::deserialize_with(self.contents, ())
}
}
impl TryFrom<u8> for MessageType {
type Error = Error;
fn try_from(v: u8) -> Result<Self> {
// https://www.postgresql.org/docs/current/protocol-message-formats.html
Ok(match v {
b'1' => MessageType::ParseComplete,
b'2' => MessageType::BindComplete,
b'3' => MessageType::CloseComplete,
b'C' => MessageType::CommandComplete,
b'D' => MessageType::DataRow,
b'E' => MessageType::ErrorResponse,
b'I' => MessageType::EmptyQueryResponse,
b'A' => MessageType::NotificationResponse,
b'K' => MessageType::BackendKeyData,
b'N' => MessageType::NoticeResponse,
b'R' => MessageType::Authentication,
b'S' => MessageType::ParameterStatus,
b'T' => MessageType::RowDescription,
b'Z' => MessageType::ReadyForQuery,
b'n' => MessageType::NoData,
b's' => MessageType::PortalSuspended,
b't' => MessageType::ParameterDescription,
_ => {
return Err(Error::configuration_msg(format!(
"unknown message type: {:?}",
v as char
)));
}
})
}
}
impl Deserialize<'_, ()> for Message {
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
let r#type = MessageType::try_from(buf.get_u8())?;
let size = buf.get_u32() - 4;
let contents = buf.split_to(size as usize);
Ok(Message { r#type, contents })
}
}
pub(crate) use message::{BackendMessage, BackendMessageType};

View File

@ -1,204 +0,0 @@
use bytes::{Buf, Bytes};
use memchr::memchr;
use sqlx_core::io::Deserialize;
use sqlx_core::Error;
use sqlx_core::Result;
// On startup, the server sends an appropriate authentication request message,
// to which the frontend must reply with an appropriate authentication
// response message (such as a password).
// For all authentication methods except GSSAPI, SSPI and SASL, there is at
// most one request and one response. In some methods, no response at all is
// needed from the frontend, and so no authentication request occurs.
// For GSSAPI, SSPI and SASL, multiple exchanges of packets may
// be needed to complete the authentication.
// <https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.3>
// <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
#[derive(Debug)]
pub enum Authentication {
/// The authentication exchange is successfully completed.
Ok,
/// The frontend must now send a [PasswordMessage] containing the
/// password in clear-text form.
CleartextPassword,
/// The frontend must now send a [PasswordMessage] containing the
/// password (with user name) encrypted via MD5, then encrypted
/// again using the 4-byte random salt.
Md5Password(AuthenticationMd5Password),
/// The frontend must now initiate a SASL negotiation,
/// using one of the SASL mechanisms listed in the message.
///
/// The frontend will send a [SaslInitialResponse] with the name
/// of the selected mechanism, and the first part of the SASL
/// data stream in response to this.
///
/// If further messages are needed, the server will
/// respond with [Authentication::SaslContinue].
Sasl(AuthenticationSasl),
/// This message contains challenge data from the previous step of SASL negotiation.
///
/// The frontend must respond with a [SaslResponse] message.
SaslContinue(AuthenticationSaslContinue),
/// SASL authentication has completed with additional mechanism-specific
/// data for the client.
///
/// The server will next send [Authentication::Ok] to
/// indicate successful authentication.
SaslFinal(AuthenticationSaslFinal),
}
impl Deserialize<'_, ()> for Authentication {
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
Ok(match buf.get_u32() {
0 => Authentication::Ok,
3 => Authentication::CleartextPassword,
5 => {
let mut salt = [0; 4];
buf.copy_to_slice(&mut salt);
Authentication::Md5Password(AuthenticationMd5Password { salt })
}
10 => Authentication::Sasl(AuthenticationSasl(buf)),
11 => {
Authentication::SaslContinue(AuthenticationSaslContinue::deserialize_with(buf, ())?)
}
12 => Authentication::SaslFinal(AuthenticationSaslFinal::deserialize_with(buf, ())?),
ty => {
return Err(Error::configuration_msg(format!(
"unknown authentication method: {}",
ty
)));
}
})
}
}
/// Body of [Authentication::Md5Password].
#[derive(Debug)]
pub struct AuthenticationMd5Password {
pub salt: [u8; 4],
}
/// Body of [Authentication::Sasl].
#[derive(Debug)]
pub struct AuthenticationSasl(Bytes);
impl AuthenticationSasl {
#[inline]
pub fn mechanisms(&self) -> SaslMechanisms<'_> {
SaslMechanisms(&self.0)
}
}
/// An iterator over the SASL authentication mechanisms provided by the server.
pub struct SaslMechanisms<'a>(&'a [u8]);
impl<'a> Iterator for SaslMechanisms<'a> {
type Item = &'a str;
fn next(&mut self) -> Option<Self::Item> {
if !self.0.is_empty() && self.0[0] == b'\0' {
return None;
}
#[allow(unsafe_code)]
let mechanism = memchr(b'\0', self.0)
// UNSAFE: Postgres is expecte to return a valid UTF-8 string here
.and_then(|nul| Some(unsafe { std::str::from_utf8_unchecked(&self.0[..nul]) }))?;
self.0 = &self.0[(mechanism.len() + 1)..];
Some(mechanism)
}
}
#[derive(Debug)]
pub struct AuthenticationSaslContinue {
pub salt: Vec<u8>,
pub iterations: u32,
pub nonce: String,
pub message: String,
}
impl Deserialize<'_, ()> for AuthenticationSaslContinue {
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
let mut iterations: u32 = 4096;
let mut salt = Vec::new();
let mut nonce = Bytes::new();
// [Example]
// r=/z+giZiTxAH7r8sNAeHr7cvpqV3uo7G/bJBIJO3pjVM7t3ng,s=4UV68bIkC8f9/X8xH7aPhg==,i=4096
for item in buf.split(|b| *b == b',') {
let key = item[0];
let value = &item[2..];
match key {
b'r' => {
nonce = buf.slice_ref(value);
}
b'i' => {
iterations = atoi::atoi(value).unwrap_or(4096);
}
b's' => {
// TODO: Map error correctly
salt = base64::decode(value).unwrap();
}
_ => {}
}
}
#[allow(unsafe_code)]
Ok(Self {
iterations,
salt,
// UNSAFE: Postgres is expected to return a valid UTF-8 string here
nonce: unsafe { String::from_utf8_unchecked((*nonce).to_vec()) },
// UNSAFE: Postgres is expected to return a valid UTF-8 string here
message: unsafe { String::from_utf8_unchecked((*buf).to_vec()) },
})
}
}
#[derive(Debug)]
pub struct AuthenticationSaslFinal {
pub verifier: Vec<u8>,
}
impl Deserialize<'_, ()> for AuthenticationSaslFinal {
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
let mut verifier = Vec::new();
for item in buf.split(|b| *b == b',') {
let key = item[0];
let value = &item[2..];
if let b'v' = key {
// TODO: Map error correctly
verifier = base64::decode(value).unwrap();
}
}
Ok(Self { verifier })
}
}

View File

@ -1,25 +0,0 @@
use byteorder::{BigEndian, ByteOrder};
use bytes::Bytes;
use sqlx_core::io::Deserialize;
use sqlx_core::Error;
use sqlx_core::Result;
/// Contains cancellation key data. The frontend must save these values if it
/// wishes to be able to issue `CancelRequest` messages later.
#[derive(Debug)]
pub struct BackendKeyData {
/// The process ID of this database.
pub process_id: u32,
/// The secret key of this database.
pub secret_key: u32,
}
impl Deserialize<'_, ()> for BackendKeyData {
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
let process_id = BigEndian::read_u32(&buf);
let secret_key = BigEndian::read_u32(&buf[4..]);
Ok(Self { process_id, secret_key })
}
}

View File

@ -1,36 +0,0 @@
use sqlx_core::io::Serialize;
use sqlx_core::Result;
use crate::io::PgBufMutExt;
const CLOSE_PORTAL: u8 = b'P';
const CLOSE_STATEMENT: u8 = b'S';
#[derive(Debug)]
#[allow(dead_code)]
pub enum Close {
Statement(u32),
Portal(u32),
}
impl Serialize<'_, ()> for Close {
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
// 15 bytes for 1-digit statement/portal IDs
buf.reserve(20);
buf.push(b'C');
buf.write_length_prefixed(|buf| match self {
Close::Statement(id) => {
buf.push(CLOSE_STATEMENT);
buf.write_statement_name(*id);
}
Close::Portal(id) => {
buf.push(CLOSE_PORTAL);
buf.write_portal_name(Some(*id));
}
});
Ok(())
}
}

View File

@ -1,18 +0,0 @@
use sqlx_core::io::Serialize;
use sqlx_core::Result;
// The Flush message does not cause any specific output to be generated,
// but forces the backend to deliver any data pending in its output buffers.
// A Flush must be sent after any extended-query command except Sync, if the
// frontend wishes to examine the results of that command before issuing more commands.
#[derive(Debug)]
pub struct Flush;
impl Serialize<'_, ()> for Flush {
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
buf.push(b'H');
buf.extend(&4_i32.to_be_bytes());
}
}

View File

@ -0,0 +1,93 @@
use std::convert::TryFrom;
use std::fmt::Debug;
use bytes::Bytes;
use sqlx_core::io::Deserialize;
use sqlx_core::{Error, Result};
/// Type of the *incoming* message.
///
/// Postgres does use the same message format for client and server messages but we are only
/// interested in messages from the backend.
///
#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub(crate) enum BackendMessageType {
ParseComplete = b'1',
BindComplete = b'2',
CloseComplete = b'3',
CommandComplete = b'C',
DataRow = b'D',
ErrorResponse = b'E',
EmptyQueryResponse = b'I',
NotificationResponse = b'A',
BackendKeyData = b'K',
NoticeResponse = b'N',
Authentication = b'R',
ParameterStatus = b'S',
RowDescription = b'T',
ReadyForQuery = b'Z',
NoData = b'n',
PortalSuspended = b's',
ParameterDescription = b't',
CopyInResponse = b'G',
CopyOutResponse = b'H',
CopyBothResponse = b'W',
CopyData = b'd',
CopyDone = b'c',
}
impl TryFrom<u8> for BackendMessageType {
type Error = Error;
fn try_from(ty: u8) -> Result<Self> {
Ok(match ty {
b'1' => Self::ParseComplete,
b'2' => Self::BindComplete,
b'3' => Self::CloseComplete,
b'C' => Self::CommandComplete,
b'D' => Self::DataRow,
b'E' => Self::ErrorResponse,
b'I' => Self::EmptyQueryResponse,
b'A' => Self::NotificationResponse,
b'K' => Self::BackendKeyData,
b'N' => Self::NoticeResponse,
b'R' => Self::Authentication,
b'S' => Self::ParameterStatus,
b'T' => Self::RowDescription,
b'Z' => Self::ReadyForQuery,
b'n' => Self::NoData,
b's' => Self::PortalSuspended,
b't' => Self::ParameterDescription,
b'G' => Self::CopyInResponse,
b'H' => Self::CopyOutResponse,
b'W' => Self::CopyBothResponse,
b'd' => Self::CopyData,
b'c' => Self::CopyDone,
_ => {
todo!("protocol unexpected data error")
}
})
}
}
#[derive(Debug)]
pub(crate) struct BackendMessage {
pub(crate) r#type: BackendMessageType,
pub(crate) contents: Bytes,
}
impl BackendMessage {
#[inline]
pub(crate) fn deserialize<'de, T>(self) -> Result<T>
where
T: Deserialize<'de> + Debug,
{
let packet = T::deserialize(self.contents)?;
log::trace!("read > {:?}", packet);
Ok(packet)
}
}

View File

@ -1,28 +0,0 @@
use bytes::{Buf, Bytes};
use bytestring::ByteString;
use sqlx_core::io::BufExt;
use sqlx_core::io::Deserialize;
use sqlx_core::Result;
#[derive(Debug)]
pub struct Notification {
pub(crate) process_id: u32,
pub(crate) channel: ByteString,
pub(crate) payload: ByteString,
}
impl Deserialize<'_, ()> for Notification {
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
let process_id = buf.get_u32();
// UNSAFE: This message will not be read.
#[allow(unsafe_code)]
let channel = unsafe { buf.get_str_nul_unchecked()? };
// UNSAFE: This message will not be read.
#[allow(unsafe_code)]
let payload = unsafe { buf.get_str_nul_unchecked()? };
Ok(Self { process_id, channel, payload })
}
}

View File

@ -1,67 +0,0 @@
use std::fmt::Write;
use md5::{Digest, Md5};
use sqlx_core::io::Serialize;
use sqlx_core::io::WriteExt;
use sqlx_core::Result;
use crate::io::PgBufMutExt;
#[derive(Debug)]
pub enum Password<'a> {
Cleartext(&'a str),
Md5 { password: &'a str, username: &'a str, salt: [u8; 4] },
}
impl Password<'_> {
#[inline]
fn len(&self) -> usize {
match self {
Password::Cleartext(s) => s.len() + 5,
Password::Md5 { .. } => 35 + 5,
}
}
}
impl Serialize<'_, ()> for Password<'_> {
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
buf.reserve(1 + 4 + self.len());
buf.push(b'p');
buf.write_length_prefixed(|buf| {
match self {
Password::Cleartext(password) => {
buf.write_str_nul(password);
}
Password::Md5 { username, password, salt } => {
// The actual `PasswordMessage` can be comwriteed in SQL as
// `concat('md5', md5(concat(md5(concat(password, username)), random-salt)))`.
// Keep in mind the md5() function returns its result as a hex string.
let mut hasher = Md5::new();
hasher.update(password);
hasher.update(username);
let mut output = String::with_capacity(35);
let _ = write!(output, "{:x}", hasher.finalize_reset());
hasher.update(&output);
hasher.update(salt);
output.clear();
let _ = write!(output, "md5{:x}", hasher.finalize());
buf.write_str_nul(&output);
}
}
});
Ok(())
}
}

View File

@ -1,51 +0,0 @@
use std::convert::TryFrom;
use bytes::Bytes;
use sqlx_core::io::Deserialize;
use sqlx_core::Error;
use sqlx_core::Result;
#[derive(Debug)]
#[repr(u8)]
pub(crate) enum TransactionStatus {
/// Not in a transaction block.
Idle = b'I',
/// In a transaction block.
Transaction = b'T',
/// In a _failed_ transaction block. Queries will be rejected until block is ended.
Error = b'E',
}
impl TryFrom<u8> for TransactionStatus {
type Error = Error;
fn try_from(value: u8) -> Result<Self> {
match value {
b'I' => Ok(TransactionStatus::Idle),
b'T' => Ok(TransactionStatus::Transaction),
b'E' => Ok(TransactionStatus::Error),
status => {
return Err(Error::configuration_msg(format!(
"unknown transaction status: {:?}",
status as char,
)));
}
}
}
}
#[derive(Debug)]
pub(crate) struct ReadyForQuery {
pub transaction_status: TransactionStatus,
}
impl Deserialize<'_, ()> for ReadyForQuery {
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
let transaction_status = TransactionStatus::try_from(buf[0])?;
Ok(Self { transaction_status })
}
}

View File

@ -1,190 +0,0 @@
use std::str::from_utf8;
use bytes::Bytes;
use memchr::memchr;
use sqlx_core::io::Deserialize;
use sqlx_core::Error;
use sqlx_core::Result;
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(u8)]
pub enum PgSeverity {
Panic,
Fatal,
Error,
Warning,
Notice,
Debug,
Info,
Log,
}
impl PgSeverity {
#[inline]
pub fn is_error(self) -> bool {
matches!(self, Self::Panic | Self::Fatal | Self::Error)
}
}
impl std::convert::TryFrom<&str> for PgSeverity {
type Error = Error;
fn try_from(s: &str) -> Result<PgSeverity> {
let result = match s {
"PANIC" => PgSeverity::Panic,
"FATAL" => PgSeverity::Fatal,
"ERROR" => PgSeverity::Error,
"WARNING" => PgSeverity::Warning,
"NOTICE" => PgSeverity::Notice,
"DEBUG" => PgSeverity::Debug,
"INFO" => PgSeverity::Info,
"LOG" => PgSeverity::Log,
severity => {
return Err(Error::configuration_msg(format!("unknown severity: {:?}", severity)));
}
};
Ok(result)
}
}
#[derive(Debug)]
pub struct Notice {
storage: Bytes,
severity: PgSeverity,
message: (u16, u16),
code: (u16, u16),
}
impl Notice {
#[inline]
pub fn severity(&self) -> PgSeverity {
self.severity
}
#[inline]
pub fn code(&self) -> &str {
self.get_cached_str(self.code)
}
#[inline]
pub fn message(&self) -> &str {
self.get_cached_str(self.message)
}
// Field descriptions available here:
// https://www.postgresql.org/docs/current/protocol-error-fields.html
#[inline]
pub fn get(&self, ty: u8) -> Option<&str> {
self.get_raw(ty).and_then(|v| from_utf8(v).ok())
}
pub fn get_raw(&self, ty: u8) -> Option<&[u8]> {
self.fields()
.filter(|(field, _)| *field == ty)
.map(|(_, (start, end))| &self.storage[start as usize..end as usize])
.next()
}
}
impl Notice {
#[inline]
fn fields(&self) -> Fields<'_> {
Fields { storage: &self.storage, offset: 0 }
}
#[inline]
fn get_cached_str(&self, cache: (u16, u16)) -> &str {
// unwrap: this cannot fail at this stage
from_utf8(&self.storage[cache.0 as usize..cache.1 as usize]).unwrap()
}
}
impl Deserialize<'_, ()> for Notice {
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
// In order to support PostgreSQL 9.5 and older we need to parse the localized S field.
// Newer versions additionally come with the V field that is guaranteed to be in English.
// We thus read both versions and prefer the unlocalized one if available.
const DEFAULT_SEVERITY: PgSeverity = PgSeverity::Log;
let mut severity_v = None;
let mut severity_s = None;
let mut message = (0, 0);
let mut code = (0, 0);
// we cache the three always present fields
// this enables to keep the access time down for the fields most likely accessed
let fields = Fields { storage: &buf, offset: 0 };
for (field, v) in fields {
if message.0 != 0 && code.0 != 0 {
// stop iterating when we have the 3 fields we were looking for
// we assume V (severity) was the first field as it should be
break;
}
use std::convert::TryInto;
match field {
b'S' => {
// Discard potential errors, because the message might be localized
severity_s =
from_utf8(&buf[v.0 as usize..v.1 as usize]).unwrap().try_into().ok();
}
b'V' => {
// Propagate errors here, because V is not localized and thus we are missing a possible
// variant.
severity_v =
Some(from_utf8(&buf[v.0 as usize..v.1 as usize]).unwrap().try_into()?);
}
b'M' => {
message = v;
}
b'C' => {
code = v;
}
_ => {}
}
}
Ok(Self {
severity: severity_v.or(severity_s).unwrap_or(DEFAULT_SEVERITY),
message,
code,
storage: buf,
})
}
}
/// An iterator over each field in the Error (or Notice) response.
struct Fields<'a> {
storage: &'a [u8],
offset: u16,
}
impl<'a> Iterator for Fields<'a> {
type Item = (u8, (u16, u16));
fn next(&mut self) -> Option<Self::Item> {
// The fields in the response body are sequentially stored as [tag][string],
// ending in a final, additional [nul]
let ty = self.storage[self.offset as usize];
if ty == 0 {
return None;
}
let nul = memchr(b'\0', &self.storage[(self.offset + 1) as usize..])? as u16;
let offset = self.offset;
self.offset += nul + 2;
Some((ty, (offset + 1, offset + nul + 1)))
}
}

View File

@ -1,40 +0,0 @@
use sqlx_core::io::Serialize;
use sqlx_core::io::WriteExt;
use sqlx_core::Result;
use crate::io::PgBufMutExt;
#[derive(Debug)]
pub struct SaslInitialResponse {
pub response: String,
pub plus: bool,
}
impl Serialize<'_, ()> for SaslInitialResponse {
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
buf.push(b'p');
buf.write_length_prefixed(|buf| {
// name of the SASL authentication mechanism that the client selected
buf.write_str_nul(if self.plus { "SCRAM-SHA-256-PLUS" } else { "SCRAM-SHA-256" });
buf.extend(&(self.response.as_bytes().len() as i32).to_be_bytes());
buf.extend(self.response.as_bytes());
});
Ok(())
}
}
#[derive(Debug)]
pub struct SaslResponse<'a>(pub &'a str);
impl Serialize<'_, ()> for SaslResponse<'_> {
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
buf.push(b'p');
buf.write_length_prefixed(|buf| {
buf.extend(self.0.as_bytes());
});
Ok(())
}
}

View File

@ -1,64 +0,0 @@
use sqlx_core::io::Serialize;
use sqlx_core::io::WriteExt;
use sqlx_core::Result;
use crate::io::PgBufMutExt;
// To begin a session, a frontend opens a connection to the server and sends a startup message.
// This message includes the names of the user and of the database the user wants to connect to;
// it also identifies the particular protocol version to be used.
// Optionally, the startup message can include additional settings for run-time parameters.
#[derive(Debug)]
pub struct Startup<'a> {
/// The database user name to connect as. Required; there is no default.
pub username: Option<&'a str>,
/// The database to connect to. Defaults to the user name.
pub database: Option<&'a str>,
/// Additional start-up params.
/// <https://www.postgresql.org/docs/devel/runtime-config-client.html>
pub params: &'a [(&'a str, &'a str)],
}
impl Serialize<'_, ()> for Startup<'_> {
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
buf.reserve(120);
buf.write_length_prefixed(|buf| {
// The protocol version number. The most significant 16 bits are the
// major version number (3 for the protocol described here). The least
// significant 16 bits are the minor version number (0
// for the protocol described here)
buf.extend(&196_608_i32.to_be_bytes());
if let Some(username) = self.username {
// The database user name to connect as.
encode_startup_param(buf, "user", username);
}
if let Some(database) = self.database {
// The database to connect to. Defaults to the user name.
encode_startup_param(buf, "database", database);
}
for (name, value) in self.params {
encode_startup_param(buf, name, value);
}
// A zero byte is required as a terminator
// after the last name/value pair.
buf.push(0);
});
Ok(())
}
}
#[inline]
fn encode_startup_param(buf: &mut Vec<u8>, name: &str, value: &str) {
buf.write_str_nul(name);
buf.write_str_nul(value);
}

View File

@ -1,14 +0,0 @@
use sqlx_core::io::Serialize;
use sqlx_core::Result;
#[derive(Debug)]
pub struct Terminate;
impl Serialize<'_, ()> for Terminate {
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
buf.push(b'X');
buf.extend(&4_u32.to_be_bytes());
Ok(())
}
}

View File

@ -0,0 +1,77 @@
use std::convert::TryInto;
use std::fmt::{self, Debug, Formatter};
use std::str::Utf8Error;
use bytes::Bytes;
use bytestring::ByteString;
use memchr::memrchr;
use sqlx_core::QueryResult;
// TODO: add unit tests for command tag parsing
/// Represents the execution result of a command in Postgres.
///
/// Returned from [`execute()`][sqlx_core::Executor::execute].
///
#[allow(clippy::module_name_repetitions)]
#[derive(Clone)]
pub struct PgQueryResult {
command: ByteString,
rows_affected: u64,
}
impl PgQueryResult {
pub(crate) fn parse(mut command: Bytes) -> Result<Self, Utf8Error> {
// look backwards for the first SPACE
let offset = memrchr(b' ', &command);
let rows = if let Some(offset) = offset {
atoi::atoi(&command.split_off(offset).slice(1..)).unwrap_or(0)
} else {
0
};
Ok(Self { command: command.try_into()?, rows_affected: rows })
}
/// Returns the command tag.
///
/// This is usually a single word that identifies which SQL command
/// was completed (e.g.,`INSERT`, `UPDATE`, or `MOVE`).
#[must_use]
pub fn command(&self) -> &str {
&self.command
}
/// Returns the number of rows inserted, deleted, updated, retrieved,
/// changed, or copied by the SQL command.
#[must_use]
pub const fn rows_affected(&self) -> u64 {
self.rows_affected
}
}
impl Debug for PgQueryResult {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("PgQueryResult")
.field("command", &self.command())
.field("rows_affected", &self.rows_affected())
.finish()
}
}
impl Extend<PgQueryResult> for PgQueryResult {
fn extend<T: IntoIterator<Item = PgQueryResult>>(&mut self, iter: T) {
for res in iter {
self.rows_affected += res.rows_affected;
self.command = res.command;
}
}
}
impl QueryResult for PgQueryResult {
#[inline]
fn rows_affected(&self) -> u64 {
self.rows_affected()
}
}

View File

@ -0,0 +1,55 @@
use bytes::Bytes;
use sqlx_core::RawValue;
use crate::{PgTypeInfo, Postgres};
/// The format of a raw SQL value for Postgres.
///
/// Postgres returns values in [`Text`] or [`Binary`] format with a
/// configuration option in a prepared query. SQLx currently hard-codes that
/// option to [`Binary`].
///
/// For simple queries, postgres only can return values in [`Text`] format.
///
#[derive(Debug, PartialEq, Copy, Clone)]
pub enum PgRawValueFormat {
Binary,
Text,
}
/// The raw representation of a SQL value for Postgres.
// 'r: row
#[derive(Debug, Clone)]
#[allow(clippy::module_name_repetitions)]
pub struct PgRawValue<'r> {
value: Option<&'r Bytes>,
format: PgRawValueFormat,
type_info: PgTypeInfo,
}
// 'r: row
impl<'r> PgRawValue<'r> {
/// Returns the type information for this value.
#[must_use]
pub const fn type_info(&self) -> &PgTypeInfo {
&self.type_info
}
/// Returns the format of this value.
#[must_use]
pub const fn format(&self) -> PgRawValueFormat {
self.format
}
}
impl<'r> RawValue<'r> for PgRawValue<'r> {
type Database = Postgres;
fn is_null(&self) -> bool {
self.value.is_none()
}
fn type_info(&self) -> &PgTypeInfo {
&self.type_info
}
}

47
sqlx-postgres/src/row.rs Normal file
View File

@ -0,0 +1,47 @@
use sqlx_core::{ColumnIndex, Result, Row};
use crate::{PgColumn, PgRawValue, Postgres};
/// A single row from a result set generated from MySQL.
#[allow(clippy::module_name_repetitions)]
pub struct PgRow {}
impl Row for PgRow {
type Database = Postgres;
fn is_null(&self) -> bool {
// self.is_null()
todo!()
}
fn len(&self) -> usize {
// self.len()
todo!()
}
fn columns(&self) -> &[PgColumn] {
// self.columns()
todo!()
}
fn try_column<I: ColumnIndex<Self>>(&self, index: I) -> Result<&PgColumn> {
// self.try_column(index)
todo!()
}
fn column_name(&self, index: usize) -> Option<&str> {
// self.columns.get(index).map(PgColumn::name)
todo!()
}
fn column_index(&self, name: &str) -> Option<usize> {
// self.columns.iter().position(|col| col.name() == name)
todo!()
}
#[allow(clippy::needless_lifetimes)]
fn try_get_raw<'r, I: ColumnIndex<Self>>(&'r self, index: I) -> Result<PgRawValue<'r>> {
// self.try_get_raw(index)
todo!()
}
}

161
sqlx-postgres/src/stream.rs Normal file
View File

@ -0,0 +1,161 @@
use std::convert::TryInto;
use std::fmt::Debug;
use std::ops::{Deref, DerefMut};
use bytes::Buf;
use sqlx_core::io::{BufStream, Serialize};
use sqlx_core::net::Stream as NetStream;
use sqlx_core::{Result, Runtime};
use crate::protocol::{BackendMessage, BackendMessageType};
/// Reads and writes messages to and from the PostgreSQL database server.
///
/// The logic for serializing data structures into the messages is found
/// mostly in `protocol/`.
///
/// The first byte of a message identifies the message type, and the next
/// four bytes specify the length of the rest of the message (this length
/// count includes itself, but not the message-type byte). The remaining
/// contents of the message are determined by the message type. For
/// historical reasons, the very first message sent by the client (
/// the startup message) has no initial message-type byte.
///
/// <https://dev.postgres.com/doc/internals/en/postgres-packet.html>
///
#[allow(clippy::module_name_repetitions)]
pub(crate) struct PgStream<Rt: Runtime> {
stream: BufStream<Rt, NetStream<Rt>>,
}
impl<Rt: Runtime> PgStream<Rt> {
pub(crate) fn new(stream: NetStream<Rt>) -> Self {
Self { stream: BufStream::with_capacity(stream, 4096, 1024) }
}
// all communication is through a stream of messages
pub(crate) fn write_message<'ser, T>(&'ser mut self, message: &T) -> Result<()>
where
T: Serialize<'ser> + Debug,
{
Ok(())
}
// reads and consumes a message from the stream buffer
// assumes there is a message on the stream
fn read_message(&mut self, size: usize) -> Result<Option<BackendMessage>> {
// the first byte is the message type
let ty = self.stream.get(0, 1).get_u8();
let ty: BackendMessageType = ty.try_into()?;
// the next 4 bytes was the length of the message
self.stream.consume(5);
// and now take the message contents
let contents = self.stream.take(size);
if contents.len() != size {
// TODO: return a database error
// BUG: something is very wrong somewhere if this branch is executed
// either in the SQLx Postgres driver or in the Postgres server
unimplemented!(
"Received {} bytes for packet but expecting {} bytes",
contents.len(),
size
);
}
match ty {
BackendMessageType::ErrorResponse => {
// TODO: return a proper error
unimplemented!("error response");
}
BackendMessageType::NotificationResponse => {
// TODO: handle these similar to master
Ok(None)
}
BackendMessageType::NoticeResponse => {
// TODO: log the incoming message
Ok(None)
}
BackendMessageType::ParameterStatus => {
// TODO: pull out and remember server version
Ok(None)
}
_ => Ok(Some(BackendMessage { contents, r#type: ty })),
}
}
}
macro_rules! impl_read_message {
($(@$blocking:ident)? $self:ident) => {{
Ok(loop {
// reads at least 5 bytes from the IO stream into the read buffer
impl_read_message!($(@$blocking)? @stream $self, 0, 5);
// bytes 1..4 will be the length of the message
let size = ($self.stream.get(1, 4).get_u32() - 4) as usize;
// read <size> bytes _after_ the header
impl_read_message!($(@$blocking)? @stream $self, 4, size);
if let Some(message) = $self.read_message(size)? {
break message;
}
})
}};
(@blocking @stream $self:ident, $offset:expr, $n:expr) => {
$self.stream.read($offset, $n)?;
};
(@stream $self:ident, $offset:expr, $n:expr) => {
$self.stream.read_async($offset, $n).await?;
};
}
impl<Rt: Runtime> PgStream<Rt> {
#[cfg(feature = "async")]
pub(crate) async fn read_message_async(&mut self) -> Result<BackendMessage>
where
Rt: sqlx_core::Async,
{
impl_read_message!(self)
}
#[cfg(feature = "blocking")]
pub(crate) fn read_message_blocking(&mut self) -> Result<BackendMessage>
where
Rt: sqlx_core::blocking::Runtime,
{
impl_read_message!(@blocking self)
}
}
impl<Rt: Runtime> Deref for PgStream<Rt> {
type Target = BufStream<Rt, NetStream<Rt>>;
fn deref(&self) -> &Self::Target {
&self.stream
}
}
impl<Rt: Runtime> DerefMut for PgStream<Rt> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.stream
}
}
macro_rules! read_message {
(@blocking $stream:expr) => {
$stream.read_message_blocking()?
};
($stream:expr) => {
$stream.read_message_async().await?
};
}

View File

@ -0,0 +1,120 @@
/// A unique identifier for a Postgres data type.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(
any(feature = "offline", feature = "serde"),
derive(serde::Serialize, serde::Deserialize)
)]
#[allow(clippy::module_name_repetitions)]
pub enum PgTypeId {
Oid(u32),
Name(&'static str),
}
// Data Types
// https://www.postgresql.org/docs/current/datatype.html
impl PgTypeId {
// Boolean
// https://www.postgresql.org/docs/current/datatype-boolean.html
/// The SQL standard `boolean` type.
///
/// Maps to `bool`.
///
pub const BOOLEAN: Self = Self::Oid(16);
// Integers
// https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-INT
/// A 2-byte integer.
///
/// Compatible with any primitive integer type.
///
/// Maps to `i16`.
///
#[doc(alias = "INT2")]
#[doc(alias = "SMALLSERIAL")]
pub const SMALLINT: Self = Self::Oid(21);
/// A 4-byte integer.
///
/// Compatible with any primitive integer type.
///
/// Maps to `i32`.
///
#[doc(alias = "INT4")]
#[doc(alias = "SERIAL")]
pub const INTEGER: Self = Self::Oid(23);
/// An 8-byte integer.
///
/// Compatible with any primitive integer type.
///
/// Maps to `i64`.
///
#[doc(alias = "INT8")]
#[doc(alias = "BIGSERIAL")]
pub const BIGINT: Self = Self::Oid(20);
// Arbitrary Precision Numbers
// https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-NUMERIC-DECIMAL
/// An exact numeric type with a user-specified precision.
///
/// Compatible with [`bigdecimal::BigDecimal`], [`rust_decimal::Decimal`], [`num_int::BigInt`], and any
/// primitive integer type. Truncation or loss-of-precision is considered an error
/// when decoding into the selected Rust integer type.
///
/// With a scale of `0` (e.g, `NUMERIC(17, 0)`), maps to `num_int::BigInt`; otherwise,
/// maps to [`bigdecimal::BigDecimal`] or [`rust_decimal::Decimal`] (depending on
/// enabled crate features).
///
#[doc(alias = "DECIMAL")]
pub const NUMERIC: Self = Self::Oid(1700);
// Floating-Point
// https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-FLOAT
/// A 4-byte floating-point numeric type.
///
/// Compatible with `f32` or `f64`.
///
/// Maps to `f32`.
///
#[doc(alias = "FLOAT4")]
pub const REAL: Self = Self::Oid(700);
/// An 8-byte floating-point numeric type.
///
/// Compatible with `f32` or `f64`.
///
/// Maps to `f64`.
///
#[doc(alias = "FLOAT8")]
pub const DOUBLE: Self = Self::Oid(701);
/// The `UNKNOWN` Postgres type. Returned for expressions that do not
/// have a type (e.g., `SELECT $1` with no parameter type hint
/// or `SELECT NULL`).
pub const UNKNOWN: Self = Self::Oid(705);
}
impl PgTypeId {
#[must_use]
pub(crate) const fn name(self) -> &'static str {
match self {
Self::BOOLEAN => "BOOLEAN",
Self::SMALLINT => "SMALLINT",
Self::INTEGER => "INTEGER",
Self::BIGINT => "BIGINT",
Self::NUMERIC => "NUMERIC",
Self::REAL => "REAL",
Self::DOUBLE => "DOUBLE",
_ => "UNKNOWN",
}
}
}

View File

@ -0,0 +1,42 @@
use sqlx_core::TypeInfo;
use crate::{PgTypeId, Postgres};
/// Provides information about a Postgres type.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(
any(feature = "offline", feature = "serde"),
derive(serde::Serialize, serde::Deserialize)
)]
#[allow(clippy::module_name_repetitions)]
pub struct PgTypeInfo(pub(crate) PgTypeId);
impl PgTypeInfo {
/// Returns the unique identifier for this Postgres type.
#[must_use]
pub const fn id(&self) -> PgTypeId {
self.0
}
/// Returns the name for this Postgres type.
#[must_use]
pub const fn name(&self) -> &'static str {
self.0.name()
}
}
impl TypeInfo for PgTypeInfo {
type Database = Postgres;
fn id(&self) -> PgTypeId {
self.id()
}
fn is_unknown(&self) -> bool {
matches!(self.0, PgTypeId::UNKNOWN)
}
fn name(&self) -> &'static str {
self.name()
}
}

View File

@ -0,0 +1 @@