mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-12-29 04:04:56 +00:00
refactor(postgres): baseline postgres driver against the now near-complete state of the mysql driver
This commit is contained in:
parent
39e2658537
commit
baa63d33e1
76
Cargo.lock
generated
76
Cargo.lock
generated
@ -379,16 +379,6 @@ dependencies = [
|
||||
"lazy_static",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crypto-mac"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4857fd85a0c34b3c3297875b747c1e02e06b6a0ea32dd892d8192b9ce0813ea6"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ctor"
|
||||
version = "0.1.19"
|
||||
@ -565,16 +555,6 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hmac"
|
||||
version = "0.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c1441c6b1e930e2817404b5046f1f989899143a12bf92de603b69f4e0aee1e15"
|
||||
dependencies = [
|
||||
"crypto-mac",
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "0.2.2"
|
||||
@ -595,12 +575,6 @@ dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "0.4.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736"
|
||||
|
||||
[[package]]
|
||||
name = "js-sys"
|
||||
version = "0.3.48"
|
||||
@ -630,9 +604,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.87"
|
||||
version = "0.2.88"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "265d751d31d6780a3f956bb5b8022feba2d94eeee5a84ba64f4212eedca42213"
|
||||
checksum = "03b07a082330a35e43f63177cc01689da34fbffa0105e1246cf0311472cac73a"
|
||||
|
||||
[[package]]
|
||||
name = "libm"
|
||||
@ -665,17 +639,6 @@ version = "0.1.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08"
|
||||
|
||||
[[package]]
|
||||
name = "md-5"
|
||||
version = "0.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7b5a279bb9607f9f53c22d496eade00d138d1bdcccd07d74650387cf94942a15"
|
||||
dependencies = [
|
||||
"block-buffer",
|
||||
"digest",
|
||||
"opaque-debug",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.3.4"
|
||||
@ -865,9 +828,9 @@ checksum = "d4fd5641d01c8f18a23da7b6fe29298ff4b55afcccdf78973b24cf3175fee32e"
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-lite"
|
||||
version = "0.2.5"
|
||||
version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0cf491442e4b033ed1c722cb9f0df5fcfcf4de682466c46469c36bc47dc5548a"
|
||||
checksum = "dc0e1f259c92177c30a4c9d177246edd0a3568b25756a977d0632cf8fa37e905"
|
||||
|
||||
[[package]]
|
||||
name = "pin-utils"
|
||||
@ -1019,9 +982,9 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.123"
|
||||
version = "1.0.124"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "92d5161132722baa40d802cc70b15262b98258453e85e5d1d365c757c73869ae"
|
||||
checksum = "bd761ff957cb2a45fbb9ab3da6512de9de55872866160b23c25f1a841e99d29f"
|
||||
|
||||
[[package]]
|
||||
name = "sha-1"
|
||||
@ -1159,41 +1122,20 @@ version = "0.6.0-pre"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"atoi",
|
||||
"base64",
|
||||
"bitflags",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"bytestring",
|
||||
"crypto-mac",
|
||||
"either",
|
||||
"conquer-once",
|
||||
"futures-executor",
|
||||
"futures-io",
|
||||
"futures-util",
|
||||
"hmac",
|
||||
"itoa",
|
||||
"log",
|
||||
"md-5",
|
||||
"memchr",
|
||||
"percent-encoding",
|
||||
"rand",
|
||||
"rsa",
|
||||
"sha-1",
|
||||
"sha2",
|
||||
"sqlx-core",
|
||||
"stringprep",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "stringprep"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ee348cb74b87454fff4b551cbf727025810a004f88aeacae7f85b87f4e9a1c1"
|
||||
dependencies = [
|
||||
"unicode-bidi",
|
||||
"unicode-normalization",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "subtle"
|
||||
version = "2.4.0"
|
||||
@ -1202,9 +1144,9 @@ checksum = "1e81da0851ada1f3e9d4312c704aa4f8806f0f9d69faaf8df2f3464b4a9437c2"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.60"
|
||||
version = "1.0.62"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c700597eca8a5a762beb35753ef6b94df201c81cca676604f547495a0d7f0081"
|
||||
checksum = "123a78a3596b24fee53a6464ce52d8ecbf62241e6294c7e7fe12086cd161f512"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
||||
@ -9,5 +9,5 @@ pub trait RawValue<'r>: Sized {
|
||||
fn is_null(&self) -> bool;
|
||||
|
||||
/// Returns the type information for this value.
|
||||
fn type_info(&self) -> &'r <Self::Database as Database>::TypeInfo;
|
||||
fn type_info(&self) -> &<Self::Database as Database>::TypeInfo;
|
||||
}
|
||||
|
||||
@ -23,7 +23,6 @@ default = []
|
||||
blocking = ["sqlx-core/blocking"]
|
||||
|
||||
# async runtime
|
||||
# not meant to be used directly
|
||||
async = ["futures-util", "sqlx-core/async", "futures-io"]
|
||||
|
||||
[dependencies]
|
||||
|
||||
@ -12,7 +12,7 @@ use crate::protocol::{Capabilities, Info, Status};
|
||||
/// An OK packet is sent from the server to the client to signal successful completion of a command.
|
||||
/// As of MySQL 5.7.5, OK packes are also used to indicate EOF, and EOF packets are deprecated.
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct OkPacket {
|
||||
pub(crate) affected_rows: u64,
|
||||
pub(crate) last_insert_id: u64,
|
||||
|
||||
@ -4,11 +4,12 @@ use sqlx_core::QueryResult;
|
||||
|
||||
use crate::protocol::{Info, OkPacket, Status};
|
||||
|
||||
/// Represents the execution result of an operation on the database server.
|
||||
/// Represents the execution result of an operation in MySQL.
|
||||
///
|
||||
/// Returned from [`execute()`][sqlx_core::Executor::execute].
|
||||
///
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
#[derive(Clone)]
|
||||
pub struct MySqlQueryResult(pub(crate) OkPacket);
|
||||
|
||||
impl MySqlQueryResult {
|
||||
@ -109,18 +110,6 @@ impl Debug for MySqlQueryResult {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MySqlQueryResult {
|
||||
fn default() -> Self {
|
||||
Self(OkPacket {
|
||||
affected_rows: 0,
|
||||
last_insert_id: 0,
|
||||
status: Status::empty(),
|
||||
warnings: 0,
|
||||
info: Info::default(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl From<OkPacket> for MySqlQueryResult {
|
||||
fn from(ok: OkPacket) -> Self {
|
||||
Self(ok)
|
||||
|
||||
@ -87,7 +87,7 @@ impl<'r> RawValue<'r> for MySqlRawValue<'r> {
|
||||
self.value.is_none()
|
||||
}
|
||||
|
||||
fn type_info(&self) -> &'r MySqlTypeInfo {
|
||||
fn type_info(&self) -> &MySqlTypeInfo {
|
||||
self.type_info
|
||||
}
|
||||
}
|
||||
|
||||
@ -101,7 +101,7 @@ impl MySqlTypeInfo {
|
||||
MySqlTypeId::CHAR => "CHAR",
|
||||
MySqlTypeId::TEXT => "TEXT",
|
||||
|
||||
_ => "",
|
||||
_ => "UNKNOWN",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
name = "sqlx-postgres"
|
||||
version = "0.6.0-pre"
|
||||
repository = "https://github.com/launchbadge/sqlx"
|
||||
description = "MySQL database driver for SQLx, the Rust SQL Toolkit."
|
||||
description = "PostgreSQL database driver for SQLx, the Rust SQL Toolkit."
|
||||
license = "MIT OR Apache-2.0"
|
||||
edition = "2018"
|
||||
keywords = ["postgres", "sqlx", "database"]
|
||||
@ -23,13 +23,12 @@ default = []
|
||||
blocking = ["sqlx-core/blocking"]
|
||||
|
||||
# async runtime
|
||||
# not meant to be used directly
|
||||
async = ["futures-util", "sqlx-core/async", "futures-io"]
|
||||
|
||||
[dependencies]
|
||||
atoi = "0.4.0"
|
||||
sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core" }
|
||||
futures-util = { version = "0.3.8", optional = true }
|
||||
either = "1.6.1"
|
||||
log = "0.4.11"
|
||||
bytestring = "1.0.0"
|
||||
url = "2.2.0"
|
||||
@ -38,20 +37,9 @@ futures-io = { version = "0.3", optional = true }
|
||||
bytes = "1.0"
|
||||
memchr = "2.3"
|
||||
bitflags = "1.2"
|
||||
sha-1 = "0.9.2"
|
||||
sha2 = "0.9.2"
|
||||
rsa = "0.3.0"
|
||||
base64 = "0.13.0"
|
||||
rand = "0.7"
|
||||
itoa = "0.4.7"
|
||||
atoi = "0.4.0"
|
||||
byteorder = { version = "1.4.2", default-features = false, features = [ "std" ] }
|
||||
md-5 = { version = "0.9.1", default-features = false }
|
||||
hmac = { version = "0.10.1", default-features = false }
|
||||
stringprep = "0.1.2"
|
||||
crypto-mac = "0.10.0"
|
||||
|
||||
[dev-dependencies]
|
||||
sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core", features = ["_mock"] }
|
||||
futures-executor = "0.3.8"
|
||||
anyhow = "1.0.37"
|
||||
conquer-once = "0.3.2"
|
||||
|
||||
31
sqlx-postgres/src/column.rs
Normal file
31
sqlx-postgres/src/column.rs
Normal file
@ -0,0 +1,31 @@
|
||||
use bytestring::ByteString;
|
||||
use sqlx_core::{Column, Database};
|
||||
|
||||
use crate::{PgTypeInfo, Postgres};
|
||||
|
||||
// TODO: inherent methods from <Column>
|
||||
|
||||
/// Represents a column from a query in Postgres.
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PgColumn {
|
||||
index: usize,
|
||||
name: ByteString,
|
||||
type_info: PgTypeInfo,
|
||||
}
|
||||
|
||||
impl Column for PgColumn {
|
||||
type Database = Postgres;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn index(&self) -> usize {
|
||||
self.index
|
||||
}
|
||||
|
||||
fn type_info(&self) -> &PgTypeInfo {
|
||||
&self.type_info
|
||||
}
|
||||
}
|
||||
@ -1,124 +0,0 @@
|
||||
use std::fmt::{self, Debug, Formatter};
|
||||
|
||||
use sqlx_core::io::BufStream;
|
||||
use sqlx_core::net::Stream as NetStream;
|
||||
use sqlx_core::{Close, Connect, Connection, Runtime};
|
||||
|
||||
use crate::{Postgres, PostgresConnectOptions};
|
||||
|
||||
#[macro_use]
|
||||
mod sasl;
|
||||
|
||||
mod close;
|
||||
mod connect;
|
||||
mod ping;
|
||||
mod stream;
|
||||
|
||||
/// A single connection (also known as a session) to a PostgreSQL database server.
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
pub struct PostgresConnection<Rt>
|
||||
where
|
||||
Rt: Runtime,
|
||||
{
|
||||
stream: BufStream<Rt, NetStream<Rt>>,
|
||||
|
||||
// process id of this backend
|
||||
// used to send cancel requests
|
||||
#[allow(dead_code)]
|
||||
process_id: u32,
|
||||
|
||||
// secret key of this backend
|
||||
// used to send cancel requests
|
||||
#[allow(dead_code)]
|
||||
secret_key: u32,
|
||||
}
|
||||
|
||||
impl<Rt> PostgresConnection<Rt>
|
||||
where
|
||||
Rt: Runtime,
|
||||
{
|
||||
pub(crate) fn new(stream: NetStream<Rt>) -> Self {
|
||||
Self { stream: BufStream::with_capacity(stream, 4096, 1024), process_id: 0, secret_key: 0 }
|
||||
}
|
||||
}
|
||||
|
||||
impl<Rt> Debug for PostgresConnection<Rt>
|
||||
where
|
||||
Rt: Runtime,
|
||||
{
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("PostgresConnection").finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Rt> Connection<Rt> for PostgresConnection<Rt>
|
||||
where
|
||||
Rt: Runtime,
|
||||
{
|
||||
type Database = Postgres;
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
fn ping(&mut self) -> futures_util::future::BoxFuture<'_, sqlx_core::Result<()>>
|
||||
where
|
||||
Rt: sqlx_core::Async,
|
||||
{
|
||||
Box::pin(self.ping_async())
|
||||
}
|
||||
}
|
||||
|
||||
impl<Rt: Runtime> Connect<Rt> for PostgresConnection<Rt> {
|
||||
type Options = PostgresConnectOptions<Rt>;
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
fn connect(url: &str) -> futures_util::future::BoxFuture<'_, sqlx_core::Result<Self>>
|
||||
where
|
||||
Self: Sized,
|
||||
Rt: sqlx_core::Async,
|
||||
{
|
||||
use sqlx_core::ConnectOptions;
|
||||
|
||||
let options = url.parse::<Self::Options>();
|
||||
Box::pin(async move { options?.connect().await })
|
||||
}
|
||||
}
|
||||
|
||||
impl<Rt: Runtime> Close<Rt> for PostgresConnection<Rt> {
|
||||
#[cfg(feature = "async")]
|
||||
fn close(self) -> futures_util::future::BoxFuture<'static, sqlx_core::Result<()>>
|
||||
where
|
||||
Rt: sqlx_core::Async,
|
||||
{
|
||||
Box::pin(self.close_async())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "blocking")]
|
||||
mod blocking {
|
||||
use sqlx_core::blocking::{Close, Connect, Connection, Runtime};
|
||||
|
||||
use super::{PostgresConnectOptions, PostgresConnection};
|
||||
|
||||
impl<Rt: Runtime> Connection<Rt> for PostgresConnection<Rt> {
|
||||
#[inline]
|
||||
fn ping(&mut self) -> sqlx_core::Result<()> {
|
||||
self.ping()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Rt: Runtime> Connect<Rt> for PostgresConnection<Rt> {
|
||||
#[inline]
|
||||
fn connect(url: &str) -> sqlx_core::Result<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
Self::connect(&url.parse::<PostgresConnectOptions<Rt>>()?)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Rt: Runtime> Close<Rt> for PostgresConnection<Rt> {
|
||||
#[inline]
|
||||
fn close(self) -> sqlx_core::Result<()> {
|
||||
self.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,32 +0,0 @@
|
||||
use sqlx_core::{io::Stream, Result, Runtime};
|
||||
|
||||
use crate::protocol::Terminate;
|
||||
|
||||
impl<Rt> super::PostgresConnection<Rt>
|
||||
where
|
||||
Rt: Runtime,
|
||||
{
|
||||
#[cfg(feature = "async")]
|
||||
pub(crate) async fn close_async(mut self) -> Result<()>
|
||||
where
|
||||
Rt: sqlx_core::Async,
|
||||
{
|
||||
self.write_packet(&Terminate)?;
|
||||
self.stream.flush_async().await?;
|
||||
self.stream.shutdown_async().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "blocking")]
|
||||
pub(crate) fn close(mut self) -> Result<()>
|
||||
where
|
||||
Rt: sqlx_core::blocking::Runtime,
|
||||
{
|
||||
self.write_packet(&Terminate)?;
|
||||
self.stream.flush()?;
|
||||
self.stream.shutdown()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -1,180 +0,0 @@
|
||||
//! Implements the connection phase.
|
||||
//!
|
||||
//! The connection phase (establish) performs these tasks:
|
||||
//!
|
||||
//! - exchange the capabilities of client and server
|
||||
//! - setup SSL communication channel if requested
|
||||
//! - authenticate the client against the server
|
||||
//!
|
||||
//! The server may immediately send an ERR packet and finish the handshake
|
||||
//! or send a `Handshake`.
|
||||
//!
|
||||
//! https://dev.postgres.com/doc/internals/en/connection-phase.html
|
||||
//!
|
||||
use hmac::{Hmac, Mac, NewMac};
|
||||
use sha2::{Digest, Sha256};
|
||||
use sqlx_core::net::Stream as NetStream;
|
||||
use sqlx_core::Error;
|
||||
use sqlx_core::Result;
|
||||
|
||||
use crate::protocol::{
|
||||
Authentication, BackendKeyData, Message, MessageType, Password, ReadyForQuery,
|
||||
SaslInitialResponse, SaslResponse, Startup,
|
||||
};
|
||||
use crate::{PostgresConnectOptions, PostgresConnection, PostgresDatabaseError};
|
||||
|
||||
macro_rules! connect {
|
||||
(@blocking @tcp $options:ident) => {
|
||||
NetStream::connect($options.address.as_ref())?;
|
||||
};
|
||||
|
||||
(@tcp $options:ident) => {
|
||||
NetStream::connect_async($options.address.as_ref()).await?;
|
||||
};
|
||||
|
||||
(@blocking @packet $self:ident) => {
|
||||
$self.read_packet()?;
|
||||
};
|
||||
|
||||
(@packet $self:ident) => {
|
||||
$self.read_packet_async().await?;
|
||||
};
|
||||
|
||||
($(@$blocking:ident)? $options:ident) => {{
|
||||
// open a network stream to the database server
|
||||
let stream = connect!($(@$blocking)? @tcp $options);
|
||||
|
||||
// construct a <PostgresConnection> around the network stream
|
||||
// wraps the stream in a <BufStream> to buffer read and write
|
||||
let mut self_ = Self::new(stream);
|
||||
|
||||
// To begin a session, a frontend opens a connection to the server
|
||||
// and sends a startup message.
|
||||
|
||||
let mut params = vec![ // Sets the display format for date and time values, as well as the rules for interpreting ambiguous date input values. ("DateStyle", "ISO, MDY"),
|
||||
// Sets the client-side encoding (character set).
|
||||
// <https://www.postgresql.org/docs/devel/multibyte.html#MULTIBYTE-CHARSET-SUPPORTED>
|
||||
("client_encoding", "UTF8"),
|
||||
// Sets the time zone for displaying and interpreting time stamps.
|
||||
("TimeZone", "UTC"),
|
||||
// Adjust postgres to return precise values for floats
|
||||
// NOTE: This is default in postgres 12+
|
||||
("extra_float_digits", "3"),
|
||||
];
|
||||
|
||||
// if let Some(ref application_name) = $options.get_application_name() {
|
||||
// params.push(("application_name", application_name));
|
||||
// }
|
||||
|
||||
self_.write_packet(&Startup {
|
||||
username: $options.get_username(),
|
||||
database: $options.get_database(),
|
||||
params: ¶ms,
|
||||
})?;
|
||||
|
||||
// The server then uses this information and the contents of
|
||||
// its configuration files (such as pg_hba.conf) to determine whether the connection is
|
||||
// provisionally acceptable, and what additional
|
||||
// authentication is required (if any).
|
||||
|
||||
let mut process_id = 0;
|
||||
let mut secret_key = 0;
|
||||
let transaction_status;
|
||||
|
||||
loop {
|
||||
let message: Message = connect!($(@$blocking)? @packet self_);
|
||||
match message.r#type {
|
||||
MessageType::Authentication => match message.decode()? {
|
||||
Authentication::Ok => {
|
||||
// the authentication exchange is successfully completed
|
||||
// do nothing; no more information is required to continue
|
||||
}
|
||||
|
||||
Authentication::CleartextPassword => {
|
||||
// The frontend must now send a [PasswordMessage] containing the
|
||||
// password in clear-text form.
|
||||
|
||||
self_
|
||||
.write_packet(&Password::Cleartext(
|
||||
$options.get_password().unwrap_or_default(),
|
||||
))?;
|
||||
}
|
||||
|
||||
Authentication::Md5Password(body) => {
|
||||
// The frontend must now send a [PasswordMessage] containing the
|
||||
// password (with user name) encrypted via MD5, then encrypted again
|
||||
// using the 4-byte random salt specified in the
|
||||
// [AuthenticationMD5Password] message.
|
||||
|
||||
self_
|
||||
.write_packet(&Password::Md5 {
|
||||
username: $options.get_username().unwrap_or_default(),
|
||||
password: $options.get_password().unwrap_or_default(),
|
||||
salt: body.salt,
|
||||
})?;
|
||||
}
|
||||
|
||||
Authentication::Sasl(body) => {
|
||||
sasl_authenticate!($(@$blocking)? self_, $options, body)
|
||||
}
|
||||
|
||||
method => {
|
||||
return Err(Error::configuration_msg(format!(
|
||||
"unsupported authentication method: {:?}",
|
||||
method
|
||||
)));
|
||||
}
|
||||
},
|
||||
|
||||
MessageType::BackendKeyData => {
|
||||
// provides secret-key data that the frontend must save if it wants to be
|
||||
// able to issue cancel requests later
|
||||
|
||||
let data: BackendKeyData = message.decode()?;
|
||||
|
||||
process_id = data.process_id;
|
||||
secret_key = data.secret_key;
|
||||
}
|
||||
|
||||
MessageType::ReadyForQuery => {
|
||||
let ready: ReadyForQuery = message.decode()?;
|
||||
|
||||
// start-up is completed. The frontend can now issue commands
|
||||
transaction_status = ready.transaction_status;
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
_ => {
|
||||
return Err(Error::configuration_msg(format!(
|
||||
"establish: unexpected message: {:?}",
|
||||
message.r#type
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(self_)
|
||||
}};
|
||||
}
|
||||
|
||||
impl<Rt> PostgresConnection<Rt>
|
||||
where
|
||||
Rt: sqlx_core::Runtime,
|
||||
{
|
||||
#[cfg(feature = "async")]
|
||||
pub(crate) async fn connect_async(options: &PostgresConnectOptions<Rt>) -> Result<Self>
|
||||
where
|
||||
Rt: sqlx_core::Async,
|
||||
{
|
||||
connect!(options)
|
||||
}
|
||||
|
||||
#[cfg(feature = "blocking")]
|
||||
pub(crate) fn connect(options: &PostgresConnectOptions<Rt>) -> Result<Self>
|
||||
where
|
||||
Rt: sqlx_core::blocking::Runtime,
|
||||
{
|
||||
connect!(@blocking options)
|
||||
}
|
||||
}
|
||||
@ -1,22 +0,0 @@
|
||||
use sqlx_core::{Result, Runtime};
|
||||
|
||||
impl<Rt> super::PostgresConnection<Rt>
|
||||
where
|
||||
Rt: Runtime,
|
||||
{
|
||||
#[cfg(feature = "async")]
|
||||
pub(crate) async fn ping_async(&mut self) -> Result<()>
|
||||
where
|
||||
Rt: sqlx_core::Async,
|
||||
{
|
||||
todo!();
|
||||
}
|
||||
|
||||
#[cfg(feature = "blocking")]
|
||||
pub(crate) fn ping(&mut self) -> Result<()>
|
||||
where
|
||||
Rt: sqlx_core::blocking::Runtime,
|
||||
{
|
||||
todo!();
|
||||
}
|
||||
}
|
||||
@ -1,294 +0,0 @@
|
||||
use hmac::{Hmac, Mac, NewMac};
|
||||
use rand::Rng;
|
||||
use sha2::digest::Digest;
|
||||
use sha2::Sha256;
|
||||
use sqlx_core::Error;
|
||||
use sqlx_core::Result;
|
||||
use sqlx_core::Runtime;
|
||||
|
||||
use crate::protocol::{
|
||||
Authentication, AuthenticationSasl, Message, MessageType, SaslInitialResponse, SaslResponse,
|
||||
};
|
||||
use crate::PostgresConnectOptions;
|
||||
use crate::PostgresConnection;
|
||||
use crate::PostgresDatabaseError;
|
||||
|
||||
pub(super) const GS2_HEADER: &str = "n,,";
|
||||
pub(super) const CHANNEL_ATTR: &str = "c";
|
||||
pub(super) const USERNAME_ATTR: &str = "n";
|
||||
pub(super) const CLIENT_PROOF_ATTR: &str = "p";
|
||||
pub(super) const NONCE_ATTR: &str = "r";
|
||||
|
||||
pub(super) fn sasl_init_response<Rt: Runtime>(
|
||||
self_: &mut PostgresConnection<Rt>,
|
||||
options: &PostgresConnectOptions<Rt>,
|
||||
data: AuthenticationSasl,
|
||||
) -> Result<(String, String, String)> {
|
||||
let mut has_sasl = false;
|
||||
let mut has_sasl_plus = false;
|
||||
let mut unknown = Vec::new();
|
||||
|
||||
for mechanism in data.mechanisms() {
|
||||
match mechanism {
|
||||
"SCRAM-SHA-256" => {
|
||||
has_sasl = true;
|
||||
}
|
||||
|
||||
"SCRAM-SHA-256-PLUS" => {
|
||||
has_sasl_plus = true;
|
||||
}
|
||||
|
||||
_ => {
|
||||
unknown.push(mechanism.to_owned());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !has_sasl_plus && !has_sasl {
|
||||
return Err(Error::connect(PostgresDatabaseError::protocol(format!(
|
||||
"unsupported SASL authentication mechanisms: {}",
|
||||
unknown.join(", ")
|
||||
))));
|
||||
}
|
||||
|
||||
// channel-binding = "c=" base64
|
||||
let channel_binding = format!(
|
||||
"{}={}",
|
||||
crate::connection::sasl::CHANNEL_ATTR,
|
||||
base64::encode(crate::connection::sasl::GS2_HEADER)
|
||||
);
|
||||
|
||||
// "n=" saslname ;; Usernames are prepared using SASLprep.
|
||||
let username = format!("{}={}", USERNAME_ATTR, options.get_username().unwrap_or_default());
|
||||
let username = match stringprep::saslprep(&username) {
|
||||
Ok(v) => v,
|
||||
Err(err) => {
|
||||
return Err(Error::connect(PostgresDatabaseError::protocol(format!(
|
||||
"failed to sasl prep the username: {:?}",
|
||||
err
|
||||
))));
|
||||
}
|
||||
};
|
||||
|
||||
// nonce = "r=" c-nonce [s-nonce] ;; Second part provided by server.
|
||||
let nonce = gen_nonce();
|
||||
|
||||
// client-first-message-bare = [reserved-mext ","] username "," nonce ["," extensions]
|
||||
let client_first_message_bare =
|
||||
format!("{username},{nonce}", username = username, nonce = nonce);
|
||||
|
||||
let client_first_message = format!(
|
||||
"{gs2_header}{client_first_message_bare}",
|
||||
gs2_header = GS2_HEADER,
|
||||
client_first_message_bare = client_first_message_bare
|
||||
);
|
||||
|
||||
self_.write_packet(&SaslInitialResponse {
|
||||
response: client_first_message.clone(),
|
||||
plus: false,
|
||||
})?;
|
||||
|
||||
Ok((channel_binding, client_first_message_bare, client_first_message))
|
||||
}
|
||||
|
||||
pub(super) fn sasl_response<'a, Rt: Runtime>(
|
||||
self_: &mut PostgresConnection<Rt>,
|
||||
message: Message,
|
||||
options: &PostgresConnectOptions<Rt>,
|
||||
channel_binding: &String,
|
||||
client_first_message_bare: &'a String,
|
||||
) -> Result<Hmac<Sha256>> {
|
||||
let cont = match message.r#type {
|
||||
MessageType::Authentication => match message.decode()? {
|
||||
Authentication::SaslContinue(data) => data,
|
||||
|
||||
auth => {
|
||||
return Err(Error::connect(PostgresDatabaseError::protocol(format!(
|
||||
"expected SASLContinue but received {:?}",
|
||||
auth
|
||||
))));
|
||||
}
|
||||
},
|
||||
|
||||
r#type => {
|
||||
return Err(Error::connect(PostgresDatabaseError::protocol(format!(
|
||||
"Expected an authencation message type, found {:?}",
|
||||
r#type
|
||||
))));
|
||||
}
|
||||
};
|
||||
|
||||
// SaltedPassword := Hi(Normalize(password), salt, i)
|
||||
let salted_password =
|
||||
hi(options.get_password().unwrap_or_default(), &cont.salt, cont.iterations)?;
|
||||
|
||||
// ClientKey := HMAC(SaltedPassword, "Client Key")
|
||||
let mut mac = Hmac::<Sha256>::new_varkey(&salted_password)
|
||||
.map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?;
|
||||
mac.update(b"Client Key");
|
||||
|
||||
let client_key = mac.finalize().into_bytes();
|
||||
|
||||
// StoredKey := H(ClientKey)
|
||||
let stored_key = Sha256::digest(&client_key);
|
||||
|
||||
// client-final-message-without-proof
|
||||
let client_final_message_wo_proof = format!(
|
||||
"{channel_binding},r={nonce}",
|
||||
channel_binding = channel_binding,
|
||||
nonce = &cont.nonce
|
||||
);
|
||||
|
||||
// AuthMessage := client-first-message-bare + "," + server-first-message + "," + client-final-message-without-proof
|
||||
let auth_message = format!(
|
||||
"{client_first_message_bare},{server_first_message},{client_final_message_wo_proof}",
|
||||
client_first_message_bare = client_first_message_bare,
|
||||
server_first_message = cont.message,
|
||||
client_final_message_wo_proof = client_final_message_wo_proof
|
||||
);
|
||||
|
||||
// ClientSignature := HMAC(StoredKey, AuthMessage)
|
||||
let mut mac = Hmac::<Sha256>::new_varkey(&stored_key)
|
||||
.map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?;
|
||||
mac.update(&auth_message.as_bytes());
|
||||
|
||||
let client_signature = mac.finalize().into_bytes();
|
||||
|
||||
// ClientProof := ClientKey XOR ClientSignature
|
||||
let client_proof: Vec<u8> =
|
||||
client_key.iter().zip(client_signature.iter()).map(|(&a, &b)| a ^ b).collect();
|
||||
|
||||
// ServerKey := HMAC(SaltedPassword, "Server Key")
|
||||
let mut mac = Hmac::<Sha256>::new_varkey(&salted_password)
|
||||
.map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?;
|
||||
mac.update(b"Server Key");
|
||||
|
||||
let server_key = mac.finalize().into_bytes();
|
||||
|
||||
// ServerSignature := HMAC(ServerKey, AuthMessage)
|
||||
let mut mac = Hmac::<Sha256>::new_varkey(&server_key)
|
||||
.map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?;
|
||||
mac.update(&auth_message.as_bytes());
|
||||
|
||||
// client-final-message = client-final-message-without-proof "," proof
|
||||
let client_final_message = format!(
|
||||
"{client_final_message_wo_proof},{client_proof_attr}={client_proof}",
|
||||
client_final_message_wo_proof = client_final_message_wo_proof,
|
||||
client_proof_attr = crate::connection::sasl::CLIENT_PROOF_ATTR,
|
||||
client_proof = base64::encode(&client_proof)
|
||||
);
|
||||
|
||||
self_.write_packet(&SaslResponse(client_first_message_bare))?;
|
||||
|
||||
Ok(mac)
|
||||
}
|
||||
|
||||
pub(super) fn sasl_final(message: Message, mac: Hmac<Sha256>) -> Result<()> {
|
||||
let data = match message.r#type {
|
||||
MessageType::Authentication => match message.decode()? {
|
||||
Authentication::SaslFinal(data) => data,
|
||||
|
||||
auth => {
|
||||
return Err(Error::connect(PostgresDatabaseError::protocol(format!(
|
||||
"expected SASLContinue but received {:?}",
|
||||
auth
|
||||
))));
|
||||
}
|
||||
},
|
||||
|
||||
r#type => {
|
||||
return Err(Error::connect(PostgresDatabaseError::protocol(format!(
|
||||
"Expected an authencation message type, found {:?}",
|
||||
r#type
|
||||
))));
|
||||
}
|
||||
};
|
||||
|
||||
// authentication is only considered valid if this verification passes
|
||||
mac.verify(&data.verifier)
|
||||
.map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
macro_rules! sasl_authenticate {
|
||||
(@blocking @packet $self:ident) => {
|
||||
$self.read_packet()?;
|
||||
};
|
||||
|
||||
(@packet $self:ident) => {
|
||||
$self.read_packet_async().await?;
|
||||
};
|
||||
|
||||
($(@$blocking:ident)? $self:ident, $options:ident, $data:ident) => {{
|
||||
let (
|
||||
channel_binding,
|
||||
client_first_message_bare,
|
||||
client_first_message,
|
||||
) = crate::connection::sasl::sasl_init_response(&mut $self, $options, $data)?;
|
||||
|
||||
let message: Message = sasl_authenticate!($(@$blocking)? @packet $self);
|
||||
let mac = crate::connection::sasl::sasl_response(
|
||||
&mut $self,
|
||||
message,
|
||||
$options,
|
||||
&channel_binding,
|
||||
&client_first_message_bare,
|
||||
)?;
|
||||
|
||||
let message: Message = sasl_authenticate!($(@$blocking)? @packet $self);
|
||||
crate::connection::sasl::sasl_final(
|
||||
message,
|
||||
mac
|
||||
)?;
|
||||
}};
|
||||
}
|
||||
|
||||
// nonce is a sequence of random printable bytes
|
||||
fn gen_nonce() -> String {
|
||||
let mut rng = rand::thread_rng();
|
||||
let count = rng.gen_range(64, 128);
|
||||
|
||||
// printable = %x21-2B / %x2D-7E
|
||||
// ;; Printable ASCII except ",".
|
||||
// ;; Note that any "printable" is also
|
||||
// ;; a valid "value".
|
||||
let nonce: String = std::iter::repeat(())
|
||||
.map(|()| {
|
||||
let mut c = rng.gen_range(0x21, 0x7F) as u8;
|
||||
|
||||
while c == 0x2C {
|
||||
c = rng.gen_range(0x21, 0x7F) as u8;
|
||||
}
|
||||
|
||||
c
|
||||
})
|
||||
.take(count)
|
||||
.map(|c| c as char)
|
||||
.collect();
|
||||
|
||||
rng.gen_range(32, 128);
|
||||
format!("{}={}", crate::connection::sasl::NONCE_ATTR, nonce)
|
||||
}
|
||||
|
||||
// Hi(str, salt, i):
|
||||
fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32]> {
|
||||
let mut mac = Hmac::<Sha256>::new_varkey(s.as_bytes())
|
||||
.map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?;
|
||||
|
||||
mac.update(&salt);
|
||||
mac.update(&1u32.to_be_bytes());
|
||||
|
||||
let mut u = mac.finalize().into_bytes();
|
||||
let mut hi = u;
|
||||
|
||||
for _ in 1..iter_count {
|
||||
let mut mac = Hmac::<Sha256>::new_varkey(s.as_bytes())
|
||||
.map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?;
|
||||
mac.update(u.as_slice());
|
||||
u = mac.finalize().into_bytes();
|
||||
hi = hi.iter().zip(u.iter()).map(|(&a, &b)| a ^ b).collect();
|
||||
}
|
||||
|
||||
Ok(hi.into())
|
||||
}
|
||||
@ -1,162 +0,0 @@
|
||||
//! Reads and writes packets to and from the PostgreSQL database server.
|
||||
//!
|
||||
//! The logic for serializing data structures into the packets is found
|
||||
//! mostly in `protocol/`.
|
||||
//!
|
||||
//! Packets in PostgreSQL are prefixed by 4 bytes.
|
||||
//! 3 for length (in LE) and a sequence id.
|
||||
//!
|
||||
//! Packets may only be as large as the communicated size in the initial
|
||||
//! `HandshakeResponse`. By default, SQLx configures its chunk size to 16M. Sending
|
||||
//! a larger payload is simply sending completely "full" packets, one after the
|
||||
//! other, with an increasing sequence id.
|
||||
//!
|
||||
//! In other words, when we sent data, we:
|
||||
//!
|
||||
//! - Split the data into "packets" of size `2 ** 24 - 1` bytes.
|
||||
//!
|
||||
//! - Prepend each packet with a **packet header**, consisting of the length of that packet,
|
||||
//! and the sequence number.
|
||||
//!
|
||||
//! https://dev.postgres.com/doc/internals/en/postgres-packet.html
|
||||
//!
|
||||
use std::convert::TryFrom;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use bytes::{Buf, Bytes};
|
||||
use log::Level;
|
||||
use sqlx_core::io::{Deserialize, Serialize};
|
||||
use sqlx_core::{Result, Runtime};
|
||||
|
||||
use crate::protocol::{Message, MessageType, Notice, PgSeverity};
|
||||
use crate::PostgresConnection;
|
||||
|
||||
impl<Rt> PostgresConnection<Rt>
|
||||
where
|
||||
Rt: Runtime,
|
||||
{
|
||||
pub(super) fn write_packet<'ser, T>(&'ser mut self, packet: &T) -> Result<()>
|
||||
where
|
||||
T: Serialize<'ser, ()> + Debug,
|
||||
{
|
||||
log::trace!("write > {:?}", packet);
|
||||
|
||||
let buf = self.stream.buffer();
|
||||
packet.serialize_with(buf, ())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! read_packet {
|
||||
($(@$blocking:ident)? $self:ident) => {{
|
||||
loop {
|
||||
read_packet!($(@$blocking)? @stream $self, 0, 5);
|
||||
|
||||
// peek at the messaage type and payload size
|
||||
let r#type = MessageType::try_from(*$self.stream.get(0, 1))?;
|
||||
let size = (u32::from_be_bytes($self.stream.get(1, 4)) - 4) as usize;
|
||||
|
||||
read_packet!($(@$blocking)? @stream $self, 5, size);
|
||||
|
||||
// take the whole packet
|
||||
$self.stream.consume(5);
|
||||
let contents = $self.stream.take(size);
|
||||
|
||||
let message = Message { r#type, contents };
|
||||
|
||||
match message.r#type {
|
||||
MessageType::ErrorResponse => {
|
||||
// An error returned from the database server.
|
||||
// return Err(PgDatabaseError(message.decode()?).into());
|
||||
panic!("got error response");
|
||||
}
|
||||
|
||||
MessageType::NotificationResponse => {
|
||||
// if let Some(buffer) = &mut self.notifications {
|
||||
// let notification: Notification = message.decode()?;
|
||||
// let _ = self.write_packet(notification);
|
||||
|
||||
// continue;
|
||||
// }
|
||||
continue;
|
||||
}
|
||||
|
||||
MessageType::ParameterStatus => {
|
||||
// informs the frontend about the current (initial)
|
||||
// setting of backend parameters
|
||||
|
||||
// we currently have no use for that data so we promptly ignore this message
|
||||
continue;
|
||||
}
|
||||
|
||||
MessageType::NoticeResponse => {
|
||||
// do we need this to be more configurable?
|
||||
// if you are reading this comment and think so, open an issue
|
||||
|
||||
let notice: Notice = message.decode()?;
|
||||
|
||||
let lvl = match notice.severity() {
|
||||
PgSeverity::Fatal | PgSeverity::Panic | PgSeverity::Error => Level::Error,
|
||||
PgSeverity::Warning => Level::Warn,
|
||||
PgSeverity::Notice => Level::Info,
|
||||
PgSeverity::Debug => Level::Debug,
|
||||
PgSeverity::Info => Level::Trace,
|
||||
PgSeverity::Log => Level::Trace,
|
||||
};
|
||||
|
||||
if lvl <= log::STATIC_MAX_LEVEL && lvl <= log::max_level() {
|
||||
log::logger().log(
|
||||
&log::Record::builder()
|
||||
.args(format_args!("{}", notice.message()))
|
||||
.level(lvl)
|
||||
.module_path_static(Some("sqlx::postgres::notice"))
|
||||
.file_static(Some(file!()))
|
||||
.line(Some(line!()))
|
||||
.build(),
|
||||
);
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
_ => {}
|
||||
}
|
||||
|
||||
return T::deserialize_with(message.contents, ());
|
||||
}
|
||||
}};
|
||||
|
||||
(@blocking @stream $self:ident, $offset:expr, $n:expr) => {
|
||||
$self.stream.read($offset, $n)?;
|
||||
};
|
||||
|
||||
(@stream $self:ident, $offset:expr, $n:expr) => {
|
||||
$self.stream.read_async($offset, $n).await?;
|
||||
};
|
||||
}
|
||||
|
||||
impl<Rt> PostgresConnection<Rt>
|
||||
where
|
||||
Rt: Runtime,
|
||||
{
|
||||
#[cfg(feature = "async")]
|
||||
|
||||
pub(super) async fn read_packet_async<'de, T>(&'de mut self) -> Result<T>
|
||||
where
|
||||
T: Deserialize<'de, ()> + Debug,
|
||||
Rt: sqlx_core::Async,
|
||||
{
|
||||
read_packet!(self)
|
||||
}
|
||||
|
||||
#[cfg(feature = "blocking")]
|
||||
|
||||
pub(super) fn read_packet<'de, T>(&'de mut self) -> Result<T>
|
||||
where
|
||||
T: Deserialize<'de, ()> + Debug,
|
||||
Rt: sqlx_core::blocking::Runtime,
|
||||
{
|
||||
read_packet!(@blocking self)
|
||||
}
|
||||
}
|
||||
@ -1,15 +1,30 @@
|
||||
use sqlx_core::{Database, HasOutput, Runtime};
|
||||
use sqlx_core::database::{HasOutput, HasRawValue};
|
||||
use sqlx_core::Database;
|
||||
|
||||
use super::{PgColumn, PgOutput, PgQueryResult, PgRawValue, PgRow, PgTypeId, PgTypeInfo};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Postgres;
|
||||
|
||||
impl<Rt> Database<Rt> for Postgres
|
||||
where
|
||||
Rt: Runtime,
|
||||
{
|
||||
type Connection = super::PostgresConnection<Rt>;
|
||||
impl Database for Postgres {
|
||||
type Column = PgColumn;
|
||||
|
||||
type Row = PgRow;
|
||||
|
||||
type QueryResult = PgQueryResult;
|
||||
|
||||
type TypeInfo = PgTypeInfo;
|
||||
|
||||
type TypeId = PgTypeId;
|
||||
}
|
||||
|
||||
// 'x: execution
|
||||
impl<'x> HasOutput<'x> for Postgres {
|
||||
type Output = &'x mut Vec<u8>;
|
||||
type Output = PgOutput<'x>;
|
||||
}
|
||||
|
||||
// 'r: row
|
||||
impl<'r> HasRawValue<'r> for Postgres {
|
||||
type Database = Self;
|
||||
type RawValue = PgRawValue<'r>;
|
||||
}
|
||||
|
||||
@ -1,41 +0,0 @@
|
||||
use std::error::Error as StdError;
|
||||
use std::fmt::{self, Display, Formatter};
|
||||
|
||||
use sqlx_core::DatabaseError;
|
||||
|
||||
/// An error returned from the PostgreSQL database server.
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
#[derive(Debug)]
|
||||
pub struct PostgresDatabaseError(String);
|
||||
|
||||
impl PostgresDatabaseError {
|
||||
pub(crate) fn protocol(msg: String) -> PostgresDatabaseError {
|
||||
PostgresDatabaseError(msg)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crypto_mac::InvalidKeyLength> for PostgresDatabaseError {
|
||||
fn from(err: crypto_mac::InvalidKeyLength) -> Self {
|
||||
PostgresDatabaseError::protocol(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crypto_mac::MacError> for PostgresDatabaseError {
|
||||
fn from(err: crypto_mac::MacError) -> Self {
|
||||
PostgresDatabaseError::protocol(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl DatabaseError for PostgresDatabaseError {
|
||||
fn message(&self) -> &str {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for PostgresDatabaseError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl StdError for PostgresDatabaseError {}
|
||||
@ -1,3 +0,0 @@
|
||||
mod write;
|
||||
|
||||
pub(crate) use write::PgBufMutExt;
|
||||
@ -1,52 +0,0 @@
|
||||
pub trait PgBufMutExt {
|
||||
fn write_length_prefixed<F>(&mut self, f: F)
|
||||
where
|
||||
F: FnOnce(&mut Vec<u8>);
|
||||
|
||||
fn write_statement_name(&mut self, id: u32);
|
||||
|
||||
fn write_portal_name(&mut self, id: Option<u32>);
|
||||
}
|
||||
|
||||
impl PgBufMutExt for Vec<u8> {
|
||||
// writes a length-prefixed message, this is used when encoding nearly all messages as postgres
|
||||
// wants us to send the length of the often-variable-sized messages up front
|
||||
fn write_length_prefixed<F>(&mut self, f: F)
|
||||
where
|
||||
F: FnOnce(&mut Vec<u8>),
|
||||
{
|
||||
// reserve space to write the prefixed length
|
||||
let offset = self.len();
|
||||
self.extend(&[0; 4]);
|
||||
|
||||
// write the main body of the message
|
||||
f(self);
|
||||
|
||||
// now calculate the size of what we wrote and set the length value
|
||||
let size = (self.len() - offset) as i32;
|
||||
self[offset..(offset + 4)].copy_from_slice(&size.to_be_bytes());
|
||||
}
|
||||
|
||||
// writes a statement name by ID
|
||||
#[inline]
|
||||
fn write_statement_name(&mut self, id: u32) {
|
||||
// N.B. if you change this don't forget to update it in ../describe.rs
|
||||
self.extend(b"sqlx_s_");
|
||||
|
||||
itoa::write(&mut *self, id).unwrap();
|
||||
|
||||
self.push(0);
|
||||
}
|
||||
|
||||
// writes a portal name by ID
|
||||
#[inline]
|
||||
fn write_portal_name(&mut self, id: Option<u32>) {
|
||||
if let Some(id) = id {
|
||||
self.extend(b"sqlx_p_");
|
||||
|
||||
itoa::write(&mut *self, id).unwrap();
|
||||
}
|
||||
|
||||
self.push(0);
|
||||
}
|
||||
}
|
||||
@ -1,6 +1,6 @@
|
||||
//! [PostgreSQL] database driver.
|
||||
//!
|
||||
//! [PostgreSQL]: https://www.postgres.com/
|
||||
//! [PostgreSQL]: https://www.postgresql.org/
|
||||
//!
|
||||
#![cfg_attr(doc_cfg, feature(doc_cfg))]
|
||||
#![cfg_attr(not(any(feature = "async", feature = "blocking")), allow(unused))]
|
||||
@ -17,18 +17,45 @@
|
||||
#![warn(clippy::use_self)]
|
||||
#![warn(clippy::useless_let_if_seq)]
|
||||
#![allow(clippy::doc_markdown)]
|
||||
#![allow(clippy::missing_errors_doc)]
|
||||
#![allow(clippy::missing_panics_doc)]
|
||||
|
||||
mod connection;
|
||||
use sqlx_core::Arguments;
|
||||
|
||||
#[macro_use]
|
||||
mod stream;
|
||||
|
||||
mod column;
|
||||
// mod connection;
|
||||
mod database;
|
||||
mod error;
|
||||
mod io;
|
||||
mod options;
|
||||
// mod error;
|
||||
// mod io;
|
||||
// mod options;
|
||||
mod output;
|
||||
mod protocol;
|
||||
mod query_result;
|
||||
// mod raw_statement;
|
||||
mod raw_value;
|
||||
mod row;
|
||||
// mod transaction;
|
||||
mod type_id;
|
||||
mod type_info;
|
||||
pub mod types;
|
||||
|
||||
#[cfg(test)]
|
||||
mod mock;
|
||||
// #[cfg(test)]
|
||||
// mod mock;
|
||||
|
||||
pub use connection::PostgresConnection;
|
||||
pub use column::PgColumn;
|
||||
// pub use connection::PgConnection;
|
||||
pub use database::Postgres;
|
||||
pub use error::PostgresDatabaseError;
|
||||
pub use options::PostgresConnectOptions;
|
||||
// pub use error::PgDatabaseError;
|
||||
// pub use options::PgConnectOptions;
|
||||
pub use output::PgOutput;
|
||||
pub use query_result::PgQueryResult;
|
||||
pub use raw_value::{PgRawValue, PgRawValueFormat};
|
||||
pub use row::PgRow;
|
||||
pub use type_id::PgTypeId;
|
||||
pub use type_info::PgTypeInfo;
|
||||
|
||||
// 'a: argument values
|
||||
pub type PgArguments<'a> = Arguments<'a, Postgres>;
|
||||
|
||||
@ -1,103 +0,0 @@
|
||||
use std::fmt::{self, Debug, Formatter};
|
||||
use std::marker::PhantomData;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use either::Either;
|
||||
use sqlx_core::{ConnectOptions, Runtime};
|
||||
|
||||
use crate::PostgresConnection;
|
||||
|
||||
mod builder;
|
||||
mod default;
|
||||
mod getters;
|
||||
mod parse;
|
||||
|
||||
// TODO: RSA Public Key (to avoid the key exchange for caching_sha2 and sha256 plugins)
|
||||
|
||||
/// Options which can be used to configure how a PostgreSQL connection is opened.
|
||||
///
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
pub struct PostgresConnectOptions<Rt>
|
||||
where
|
||||
Rt: Runtime,
|
||||
{
|
||||
runtime: PhantomData<Rt>,
|
||||
pub(crate) address: Either<(String, u16), PathBuf>,
|
||||
username: Option<String>,
|
||||
password: Option<String>,
|
||||
database: Option<String>,
|
||||
timezone: String,
|
||||
charset: String,
|
||||
}
|
||||
|
||||
impl<Rt> Clone for PostgresConnectOptions<Rt>
|
||||
where
|
||||
Rt: Runtime,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
runtime: PhantomData,
|
||||
address: self.address.clone(),
|
||||
username: self.username.clone(),
|
||||
password: self.password.clone(),
|
||||
database: self.database.clone(),
|
||||
timezone: self.timezone.clone(),
|
||||
charset: self.charset.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Rt> Debug for PostgresConnectOptions<Rt>
|
||||
where
|
||||
Rt: Runtime,
|
||||
{
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("PostgresConnectOptions")
|
||||
.field(
|
||||
"address",
|
||||
&self
|
||||
.address
|
||||
.as_ref()
|
||||
.map_left(|(host, port)| format!("{}:{}", host, port))
|
||||
.map_right(|socket| socket.display()),
|
||||
)
|
||||
.field("username", &self.username)
|
||||
.field("password", &self.password)
|
||||
.field("database", &self.database)
|
||||
.field("timezone", &self.timezone)
|
||||
.field("charset", &self.charset)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Rt> ConnectOptions<Rt> for PostgresConnectOptions<Rt>
|
||||
where
|
||||
Rt: Runtime,
|
||||
{
|
||||
type Connection = PostgresConnection<Rt>;
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
fn connect(&self) -> futures_util::future::BoxFuture<'_, sqlx_core::Result<Self::Connection>>
|
||||
where
|
||||
Self::Connection: Sized,
|
||||
Rt: sqlx_core::Async,
|
||||
{
|
||||
Box::pin(PostgresConnection::<Rt>::connect_async(self))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "blocking")]
|
||||
mod blocking {
|
||||
use sqlx_core::blocking::{ConnectOptions, Runtime};
|
||||
|
||||
use super::{PostgresConnectOptions, PostgresConnection};
|
||||
|
||||
impl<Rt: Runtime> ConnectOptions<Rt> for PostgresConnectOptions<Rt> {
|
||||
fn connect(&self) -> sqlx_core::Result<Self::Connection>
|
||||
where
|
||||
Self::Connection: Sized,
|
||||
{
|
||||
<PostgresConnection<Rt>>::connect(self)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,82 +0,0 @@
|
||||
use std::mem;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use either::Either;
|
||||
use sqlx_core::Runtime;
|
||||
|
||||
impl<Rt> super::PostgresConnectOptions<Rt>
|
||||
where
|
||||
Rt: Runtime,
|
||||
{
|
||||
/// Sets the hostname of the database server.
|
||||
///
|
||||
/// If the hostname begins with a slash (`/`), it is interpreted as the absolute path
|
||||
/// to a Unix domain socket file instead of a hostname of a server.
|
||||
///
|
||||
/// Defaults to `localhost`.
|
||||
///
|
||||
pub fn host(&mut self, host: impl AsRef<str>) -> &mut Self {
|
||||
let host = host.as_ref();
|
||||
|
||||
self.address = if host.starts_with('/') {
|
||||
Either::Right(PathBuf::from(&*host))
|
||||
} else {
|
||||
Either::Left((host.into(), self.get_port()))
|
||||
};
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the path of the Unix domain socket to connect to.
|
||||
///
|
||||
/// Overrides [`host()`](#method.host) and [`port()`](#method.port).
|
||||
///
|
||||
pub fn socket(&mut self, socket: impl AsRef<Path>) -> &mut Self {
|
||||
self.address = Either::Right(socket.as_ref().to_owned());
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the TCP port number of the database server.
|
||||
///
|
||||
/// Defaults to `3306`.
|
||||
///
|
||||
pub fn port(&mut self, port: u16) -> &mut Self {
|
||||
self.address = match self.address {
|
||||
Either::Right(_) => Either::Left(("localhost".to_owned(), port)),
|
||||
Either::Left((ref mut host, _)) => Either::Left((mem::take(host), port)),
|
||||
};
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the username to be used for authentication.
|
||||
// FIXME: Specify what happens when you do NOT set this
|
||||
pub fn username(&mut self, username: impl AsRef<str>) -> &mut Self {
|
||||
self.username = Some(username.as_ref().to_owned());
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the password to be used for authentication.
|
||||
pub fn password(&mut self, password: impl AsRef<str>) -> &mut Self {
|
||||
self.password = Some(password.as_ref().to_owned());
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the default database for the connection.
|
||||
pub fn database(&mut self, database: impl AsRef<str>) -> &mut Self {
|
||||
self.database = Some(database.as_ref().to_owned());
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the character set for the connection.
|
||||
pub fn charset(&mut self, charset: impl AsRef<str>) -> &mut Self {
|
||||
self.charset = charset.as_ref().to_owned();
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the timezone for the connection.
|
||||
pub fn timezone(&mut self, timezone: impl AsRef<str>) -> &mut Self {
|
||||
self.timezone = timezone.as_ref().to_owned();
|
||||
self
|
||||
}
|
||||
}
|
||||
@ -1,38 +0,0 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use either::Either;
|
||||
use sqlx_core::Runtime;
|
||||
|
||||
use crate::PostgresConnectOptions;
|
||||
|
||||
pub(crate) const HOST: &str = "localhost";
|
||||
pub(crate) const PORT: u16 = 3306;
|
||||
|
||||
impl<Rt> Default for PostgresConnectOptions<Rt>
|
||||
where
|
||||
Rt: Runtime,
|
||||
{
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
runtime: PhantomData,
|
||||
address: Either::Left((HOST.to_owned(), PORT)),
|
||||
username: None,
|
||||
password: None,
|
||||
database: None,
|
||||
charset: "utf8mb4".to_owned(),
|
||||
timezone: "utc".to_owned(),
|
||||
// todo: connect_timeout
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Rt> super::PostgresConnectOptions<Rt>
|
||||
where
|
||||
Rt: Runtime,
|
||||
{
|
||||
/// Creates a default set of options ready for configuration.
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
@ -1,55 +0,0 @@
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use sqlx_core::Runtime;
|
||||
|
||||
use super::{default, PostgresConnectOptions};
|
||||
|
||||
impl<Rt: Runtime> PostgresConnectOptions<Rt> {
|
||||
/// Returns the hostname of the database server.
|
||||
#[must_use]
|
||||
pub fn get_host(&self) -> &str {
|
||||
self.address.as_ref().left().map_or(default::HOST, |(host, _)| &**host)
|
||||
}
|
||||
|
||||
/// Returns the TCP port number of the database server.
|
||||
#[must_use]
|
||||
pub fn get_port(&self) -> u16 {
|
||||
self.address.as_ref().left().map_or(default::PORT, |(_, port)| *port)
|
||||
}
|
||||
|
||||
/// Returns the path to the Unix domain socket, if one is configured.
|
||||
#[must_use]
|
||||
pub fn get_socket(&self) -> Option<&Path> {
|
||||
self.address.as_ref().right().map(PathBuf::as_path)
|
||||
}
|
||||
|
||||
/// Returns the default database name.
|
||||
#[must_use]
|
||||
pub fn get_database(&self) -> Option<&str> {
|
||||
self.database.as_deref()
|
||||
}
|
||||
|
||||
/// Returns the username to be used for authentication.
|
||||
#[must_use]
|
||||
pub fn get_username(&self) -> Option<&str> {
|
||||
self.username.as_deref()
|
||||
}
|
||||
|
||||
/// Returns the password to be used for authentication.
|
||||
#[must_use]
|
||||
pub fn get_password(&self) -> Option<&str> {
|
||||
self.password.as_deref()
|
||||
}
|
||||
|
||||
/// Returns the character set for the connection.
|
||||
#[must_use]
|
||||
pub fn get_charset(&self) -> &str {
|
||||
&self.charset
|
||||
}
|
||||
|
||||
/// Returns the timezone for the connection.
|
||||
#[must_use]
|
||||
pub fn get_timezone(&self) -> &str {
|
||||
&self.timezone
|
||||
}
|
||||
}
|
||||
@ -1,183 +0,0 @@
|
||||
use std::borrow::Cow;
|
||||
use std::str::FromStr;
|
||||
|
||||
use percent_encoding::percent_decode_str;
|
||||
use sqlx_core::{Error, Runtime};
|
||||
use url::Url;
|
||||
|
||||
use crate::PostgresConnectOptions;
|
||||
|
||||
impl<Rt> FromStr for PostgresConnectOptions<Rt>
|
||||
where
|
||||
Rt: Runtime,
|
||||
{
|
||||
type Err = Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let url: Url =
|
||||
s.parse().map_err(|error| Error::configuration("for database URL", error))?;
|
||||
|
||||
if !matches!(url.scheme(), "postgres") {
|
||||
return Err(Error::configuration_msg(format!(
|
||||
"unsupported URL scheme {:?} for MySQL",
|
||||
url.scheme()
|
||||
)));
|
||||
}
|
||||
|
||||
let mut options = Self::new();
|
||||
|
||||
if let Some(host) = url.host_str() {
|
||||
options.host(percent_decode_str_utf8(host));
|
||||
}
|
||||
|
||||
if let Some(port) = url.port() {
|
||||
options.port(port);
|
||||
}
|
||||
|
||||
let username = url.username();
|
||||
if !username.is_empty() {
|
||||
options.username(percent_decode_str_utf8(username));
|
||||
}
|
||||
|
||||
if let Some(password) = url.password() {
|
||||
options.password(percent_decode_str_utf8(password));
|
||||
}
|
||||
|
||||
let mut path = url.path();
|
||||
|
||||
if path.starts_with('/') {
|
||||
path = &path[1..];
|
||||
}
|
||||
|
||||
if !path.is_empty() {
|
||||
options.database(path);
|
||||
}
|
||||
|
||||
for (key, value) in url.query_pairs() {
|
||||
match &*key {
|
||||
"user" | "username" => {
|
||||
options.username(value);
|
||||
}
|
||||
|
||||
"password" => {
|
||||
options.password(value);
|
||||
}
|
||||
|
||||
// ssl-mode compatibly with SQLx <= 0.5
|
||||
// sslmode compatibly with PostgreSQL
|
||||
// sslMode compatibly with JDBC MySQL
|
||||
// tls compatibly with Go MySQL [preferred]
|
||||
"ssl-mode" | "sslmode" | "sslMode" | "tls" => {
|
||||
todo!()
|
||||
}
|
||||
|
||||
"charset" => {
|
||||
options.charset(value);
|
||||
}
|
||||
|
||||
"timezone" => {
|
||||
options.timezone(value);
|
||||
}
|
||||
|
||||
"socket" => {
|
||||
options.socket(&*value);
|
||||
}
|
||||
|
||||
_ => {
|
||||
// ignore unknown connection parameters
|
||||
// fixme: should we error or warn here?
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(options)
|
||||
}
|
||||
}
|
||||
|
||||
// todo: this should probably go somewhere common
|
||||
fn percent_decode_str_utf8(value: &str) -> Cow<'_, str> {
|
||||
percent_decode_str(value).decode_utf8_lossy()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::Path;
|
||||
|
||||
use sqlx_core::mock::Mock;
|
||||
|
||||
use super::PostgresConnectOptions;
|
||||
|
||||
#[test]
|
||||
fn parse() {
|
||||
let url = "postgres://user:password@hostname:5432/database?timezone=system&charset=utf8";
|
||||
let options: PostgresConnectOptions<Mock> = url.parse().unwrap();
|
||||
|
||||
assert_eq!(options.get_username(), Some("user"));
|
||||
assert_eq!(options.get_password(), Some("password"));
|
||||
assert_eq!(options.get_host(), "hostname");
|
||||
assert_eq!(options.get_port(), 5432);
|
||||
assert_eq!(options.get_database(), Some("database"));
|
||||
assert_eq!(options.get_timezone(), "system");
|
||||
assert_eq!(options.get_charset(), "utf8");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_with_defaults() {
|
||||
let url = "postgres://";
|
||||
let options: PostgresConnectOptions<Mock> = url.parse().unwrap();
|
||||
|
||||
assert_eq!(options.get_username(), None);
|
||||
assert_eq!(options.get_password(), None);
|
||||
assert_eq!(options.get_host(), "localhost");
|
||||
assert_eq!(options.get_port(), 3306);
|
||||
assert_eq!(options.get_database(), None);
|
||||
assert_eq!(options.get_timezone(), "utc");
|
||||
assert_eq!(options.get_charset(), "utf8mb4");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_socket_from_query() {
|
||||
let url = "postgres://user:password@localhost/database?socket=/var/run/postgresd/postgresd.sock";
|
||||
let options: PostgresConnectOptions<Mock> = url.parse().unwrap();
|
||||
|
||||
assert_eq!(options.get_username(), Some("user"));
|
||||
assert_eq!(options.get_password(), Some("password"));
|
||||
assert_eq!(options.get_database(), Some("database"));
|
||||
assert_eq!(options.get_socket(), Some(Path::new("/var/run/postgresd/postgresd.sock")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_socket_from_host() {
|
||||
// socket path in host requires URL encoding - but does work
|
||||
let url = "postgres://user:password@%2Fvar%2Frun%2Fpostgresd%2Fpostgresd.sock/database";
|
||||
let options: PostgresConnectOptions<Mock> = url.parse().unwrap();
|
||||
|
||||
assert_eq!(options.get_username(), Some("user"));
|
||||
assert_eq!(options.get_password(), Some("password"));
|
||||
assert_eq!(options.get_database(), Some("database"));
|
||||
assert_eq!(options.get_socket(), Some(Path::new("/var/run/postgresd/postgresd.sock")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn fail_to_parse_non_postgres() {
|
||||
let url = "postgres://user:password@hostname:5432/database?timezone=system&charset=utf8";
|
||||
let _: PostgresConnectOptions<Mock> = url.parse().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_username_with_at_sign() {
|
||||
let url = "postgres://user@hostname:password@hostname:5432/database";
|
||||
let options: PostgresConnectOptions<Mock> = url.parse().unwrap();
|
||||
|
||||
assert_eq!(options.get_username(), Some("user@hostname"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_password_with_non_ascii_chars() {
|
||||
let url = "postgres://username:p@ssw0rd@hostname:5432/database";
|
||||
let options: PostgresConnectOptions<Mock> = url.parse().unwrap();
|
||||
|
||||
assert_eq!(options.get_password(), Some("p@ssw0rd"));
|
||||
}
|
||||
}
|
||||
15
sqlx-postgres/src/output.rs
Normal file
15
sqlx-postgres/src/output.rs
Normal file
@ -0,0 +1,15 @@
|
||||
// 'x: execution
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
pub struct PgOutput<'x> {
|
||||
buffer: &'x mut Vec<u8>,
|
||||
}
|
||||
|
||||
impl<'x> PgOutput<'x> {
|
||||
pub(crate) fn new(buffer: &'x mut Vec<u8>) -> Self {
|
||||
Self { buffer }
|
||||
}
|
||||
|
||||
pub(crate) fn buffer(&mut self) -> &mut Vec<u8> {
|
||||
self.buffer
|
||||
}
|
||||
}
|
||||
@ -1,111 +1,3 @@
|
||||
use std::convert::TryFrom;
|
||||
mod message;
|
||||
|
||||
use bytes::{Buf, Bytes};
|
||||
use sqlx_core::io::Deserialize;
|
||||
use sqlx_core::Error;
|
||||
use sqlx_core::Result;
|
||||
|
||||
mod authentication;
|
||||
mod backend_key_data;
|
||||
mod close;
|
||||
mod notification;
|
||||
mod password;
|
||||
mod ready_for_query;
|
||||
mod response;
|
||||
mod sasl;
|
||||
mod startup;
|
||||
mod terminate;
|
||||
|
||||
pub(crate) use authentication::{Authentication, AuthenticationSasl};
|
||||
pub(crate) use backend_key_data::BackendKeyData;
|
||||
pub(crate) use close::Close;
|
||||
pub(crate) use notification::Notification;
|
||||
pub(crate) use password::Password;
|
||||
pub(crate) use ready_for_query::ReadyForQuery;
|
||||
pub(crate) use response::{Notice, PgSeverity};
|
||||
pub(crate) use sasl::{SaslInitialResponse, SaslResponse};
|
||||
pub(crate) use startup::Startup;
|
||||
pub(crate) use terminate::Terminate;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialOrd, PartialEq, Eq)]
|
||||
#[repr(u8)]
|
||||
pub enum MessageType {
|
||||
ParseComplete = b'1',
|
||||
BindComplete = b'2',
|
||||
CloseComplete = b'3',
|
||||
CommandComplete = b'C',
|
||||
DataRow = b'D',
|
||||
ErrorResponse = b'E',
|
||||
EmptyQueryResponse = b'I',
|
||||
NotificationResponse = b'A',
|
||||
BackendKeyData = b'K',
|
||||
NoticeResponse = b'N',
|
||||
Authentication = b'R',
|
||||
ParameterStatus = b'S',
|
||||
RowDescription = b'T',
|
||||
ReadyForQuery = b'Z',
|
||||
NoData = b'n',
|
||||
PortalSuspended = b's',
|
||||
ParameterDescription = b't',
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Message {
|
||||
pub r#type: MessageType,
|
||||
pub contents: Bytes,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
#[inline]
|
||||
pub fn decode<'de, T>(self) -> Result<T>
|
||||
where
|
||||
T: Deserialize<'de, ()>,
|
||||
{
|
||||
T::deserialize_with(self.contents, ())
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<u8> for MessageType {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(v: u8) -> Result<Self> {
|
||||
// https://www.postgresql.org/docs/current/protocol-message-formats.html
|
||||
|
||||
Ok(match v {
|
||||
b'1' => MessageType::ParseComplete,
|
||||
b'2' => MessageType::BindComplete,
|
||||
b'3' => MessageType::CloseComplete,
|
||||
b'C' => MessageType::CommandComplete,
|
||||
b'D' => MessageType::DataRow,
|
||||
b'E' => MessageType::ErrorResponse,
|
||||
b'I' => MessageType::EmptyQueryResponse,
|
||||
b'A' => MessageType::NotificationResponse,
|
||||
b'K' => MessageType::BackendKeyData,
|
||||
b'N' => MessageType::NoticeResponse,
|
||||
b'R' => MessageType::Authentication,
|
||||
b'S' => MessageType::ParameterStatus,
|
||||
b'T' => MessageType::RowDescription,
|
||||
b'Z' => MessageType::ReadyForQuery,
|
||||
b'n' => MessageType::NoData,
|
||||
b's' => MessageType::PortalSuspended,
|
||||
b't' => MessageType::ParameterDescription,
|
||||
|
||||
_ => {
|
||||
return Err(Error::configuration_msg(format!(
|
||||
"unknown message type: {:?}",
|
||||
v as char
|
||||
)));
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Deserialize<'_, ()> for Message {
|
||||
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
|
||||
let r#type = MessageType::try_from(buf.get_u8())?;
|
||||
let size = buf.get_u32() - 4;
|
||||
let contents = buf.split_to(size as usize);
|
||||
|
||||
Ok(Message { r#type, contents })
|
||||
}
|
||||
}
|
||||
pub(crate) use message::{BackendMessage, BackendMessageType};
|
||||
|
||||
@ -1,204 +0,0 @@
|
||||
use bytes::{Buf, Bytes};
|
||||
use memchr::memchr;
|
||||
use sqlx_core::io::Deserialize;
|
||||
use sqlx_core::Error;
|
||||
use sqlx_core::Result;
|
||||
|
||||
// On startup, the server sends an appropriate authentication request message,
|
||||
// to which the frontend must reply with an appropriate authentication
|
||||
// response message (such as a password).
|
||||
|
||||
// For all authentication methods except GSSAPI, SSPI and SASL, there is at
|
||||
// most one request and one response. In some methods, no response at all is
|
||||
// needed from the frontend, and so no authentication request occurs.
|
||||
|
||||
// For GSSAPI, SSPI and SASL, multiple exchanges of packets may
|
||||
// be needed to complete the authentication.
|
||||
|
||||
// <https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.3>
|
||||
// <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Authentication {
|
||||
/// The authentication exchange is successfully completed.
|
||||
Ok,
|
||||
|
||||
/// The frontend must now send a [PasswordMessage] containing the
|
||||
/// password in clear-text form.
|
||||
CleartextPassword,
|
||||
|
||||
/// The frontend must now send a [PasswordMessage] containing the
|
||||
/// password (with user name) encrypted via MD5, then encrypted
|
||||
/// again using the 4-byte random salt.
|
||||
Md5Password(AuthenticationMd5Password),
|
||||
|
||||
/// The frontend must now initiate a SASL negotiation,
|
||||
/// using one of the SASL mechanisms listed in the message.
|
||||
///
|
||||
/// The frontend will send a [SaslInitialResponse] with the name
|
||||
/// of the selected mechanism, and the first part of the SASL
|
||||
/// data stream in response to this.
|
||||
///
|
||||
/// If further messages are needed, the server will
|
||||
/// respond with [Authentication::SaslContinue].
|
||||
Sasl(AuthenticationSasl),
|
||||
|
||||
/// This message contains challenge data from the previous step of SASL negotiation.
|
||||
///
|
||||
/// The frontend must respond with a [SaslResponse] message.
|
||||
SaslContinue(AuthenticationSaslContinue),
|
||||
|
||||
/// SASL authentication has completed with additional mechanism-specific
|
||||
/// data for the client.
|
||||
///
|
||||
/// The server will next send [Authentication::Ok] to
|
||||
/// indicate successful authentication.
|
||||
SaslFinal(AuthenticationSaslFinal),
|
||||
}
|
||||
|
||||
impl Deserialize<'_, ()> for Authentication {
|
||||
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
|
||||
Ok(match buf.get_u32() {
|
||||
0 => Authentication::Ok,
|
||||
|
||||
3 => Authentication::CleartextPassword,
|
||||
|
||||
5 => {
|
||||
let mut salt = [0; 4];
|
||||
buf.copy_to_slice(&mut salt);
|
||||
|
||||
Authentication::Md5Password(AuthenticationMd5Password { salt })
|
||||
}
|
||||
|
||||
10 => Authentication::Sasl(AuthenticationSasl(buf)),
|
||||
|
||||
11 => {
|
||||
Authentication::SaslContinue(AuthenticationSaslContinue::deserialize_with(buf, ())?)
|
||||
}
|
||||
|
||||
12 => Authentication::SaslFinal(AuthenticationSaslFinal::deserialize_with(buf, ())?),
|
||||
|
||||
ty => {
|
||||
return Err(Error::configuration_msg(format!(
|
||||
"unknown authentication method: {}",
|
||||
ty
|
||||
)));
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Body of [Authentication::Md5Password].
|
||||
#[derive(Debug)]
|
||||
pub struct AuthenticationMd5Password {
|
||||
pub salt: [u8; 4],
|
||||
}
|
||||
|
||||
/// Body of [Authentication::Sasl].
|
||||
#[derive(Debug)]
|
||||
pub struct AuthenticationSasl(Bytes);
|
||||
|
||||
impl AuthenticationSasl {
|
||||
#[inline]
|
||||
pub fn mechanisms(&self) -> SaslMechanisms<'_> {
|
||||
SaslMechanisms(&self.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// An iterator over the SASL authentication mechanisms provided by the server.
|
||||
pub struct SaslMechanisms<'a>(&'a [u8]);
|
||||
|
||||
impl<'a> Iterator for SaslMechanisms<'a> {
|
||||
type Item = &'a str;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if !self.0.is_empty() && self.0[0] == b'\0' {
|
||||
return None;
|
||||
}
|
||||
|
||||
#[allow(unsafe_code)]
|
||||
let mechanism = memchr(b'\0', self.0)
|
||||
// UNSAFE: Postgres is expecte to return a valid UTF-8 string here
|
||||
.and_then(|nul| Some(unsafe { std::str::from_utf8_unchecked(&self.0[..nul]) }))?;
|
||||
|
||||
self.0 = &self.0[(mechanism.len() + 1)..];
|
||||
|
||||
Some(mechanism)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AuthenticationSaslContinue {
|
||||
pub salt: Vec<u8>,
|
||||
pub iterations: u32,
|
||||
pub nonce: String,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
impl Deserialize<'_, ()> for AuthenticationSaslContinue {
|
||||
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
|
||||
let mut iterations: u32 = 4096;
|
||||
let mut salt = Vec::new();
|
||||
let mut nonce = Bytes::new();
|
||||
|
||||
// [Example]
|
||||
// r=/z+giZiTxAH7r8sNAeHr7cvpqV3uo7G/bJBIJO3pjVM7t3ng,s=4UV68bIkC8f9/X8xH7aPhg==,i=4096
|
||||
|
||||
for item in buf.split(|b| *b == b',') {
|
||||
let key = item[0];
|
||||
let value = &item[2..];
|
||||
|
||||
match key {
|
||||
b'r' => {
|
||||
nonce = buf.slice_ref(value);
|
||||
}
|
||||
|
||||
b'i' => {
|
||||
iterations = atoi::atoi(value).unwrap_or(4096);
|
||||
}
|
||||
|
||||
b's' => {
|
||||
// TODO: Map error correctly
|
||||
salt = base64::decode(value).unwrap();
|
||||
}
|
||||
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unsafe_code)]
|
||||
Ok(Self {
|
||||
iterations,
|
||||
salt,
|
||||
|
||||
// UNSAFE: Postgres is expected to return a valid UTF-8 string here
|
||||
nonce: unsafe { String::from_utf8_unchecked((*nonce).to_vec()) },
|
||||
|
||||
// UNSAFE: Postgres is expected to return a valid UTF-8 string here
|
||||
message: unsafe { String::from_utf8_unchecked((*buf).to_vec()) },
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AuthenticationSaslFinal {
|
||||
pub verifier: Vec<u8>,
|
||||
}
|
||||
|
||||
impl Deserialize<'_, ()> for AuthenticationSaslFinal {
|
||||
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
|
||||
let mut verifier = Vec::new();
|
||||
|
||||
for item in buf.split(|b| *b == b',') {
|
||||
let key = item[0];
|
||||
let value = &item[2..];
|
||||
|
||||
if let b'v' = key {
|
||||
// TODO: Map error correctly
|
||||
verifier = base64::decode(value).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self { verifier })
|
||||
}
|
||||
}
|
||||
@ -1,25 +0,0 @@
|
||||
use byteorder::{BigEndian, ByteOrder};
|
||||
use bytes::Bytes;
|
||||
use sqlx_core::io::Deserialize;
|
||||
use sqlx_core::Error;
|
||||
use sqlx_core::Result;
|
||||
|
||||
/// Contains cancellation key data. The frontend must save these values if it
|
||||
/// wishes to be able to issue `CancelRequest` messages later.
|
||||
#[derive(Debug)]
|
||||
pub struct BackendKeyData {
|
||||
/// The process ID of this database.
|
||||
pub process_id: u32,
|
||||
|
||||
/// The secret key of this database.
|
||||
pub secret_key: u32,
|
||||
}
|
||||
|
||||
impl Deserialize<'_, ()> for BackendKeyData {
|
||||
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
|
||||
let process_id = BigEndian::read_u32(&buf);
|
||||
let secret_key = BigEndian::read_u32(&buf[4..]);
|
||||
|
||||
Ok(Self { process_id, secret_key })
|
||||
}
|
||||
}
|
||||
@ -1,36 +0,0 @@
|
||||
use sqlx_core::io::Serialize;
|
||||
use sqlx_core::Result;
|
||||
|
||||
use crate::io::PgBufMutExt;
|
||||
|
||||
const CLOSE_PORTAL: u8 = b'P';
|
||||
const CLOSE_STATEMENT: u8 = b'S';
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub enum Close {
|
||||
Statement(u32),
|
||||
Portal(u32),
|
||||
}
|
||||
|
||||
impl Serialize<'_, ()> for Close {
|
||||
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
|
||||
// 15 bytes for 1-digit statement/portal IDs
|
||||
buf.reserve(20);
|
||||
buf.push(b'C');
|
||||
|
||||
buf.write_length_prefixed(|buf| match self {
|
||||
Close::Statement(id) => {
|
||||
buf.push(CLOSE_STATEMENT);
|
||||
buf.write_statement_name(*id);
|
||||
}
|
||||
|
||||
Close::Portal(id) => {
|
||||
buf.push(CLOSE_PORTAL);
|
||||
buf.write_portal_name(Some(*id));
|
||||
}
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -1,18 +0,0 @@
|
||||
use sqlx_core::io::Serialize;
|
||||
use sqlx_core::Result;
|
||||
|
||||
// The Flush message does not cause any specific output to be generated,
|
||||
// but forces the backend to deliver any data pending in its output buffers.
|
||||
|
||||
// A Flush must be sent after any extended-query command except Sync, if the
|
||||
// frontend wishes to examine the results of that command before issuing more commands.
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Flush;
|
||||
|
||||
impl Serialize<'_, ()> for Flush {
|
||||
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
|
||||
buf.push(b'H');
|
||||
buf.extend(&4_i32.to_be_bytes());
|
||||
}
|
||||
}
|
||||
93
sqlx-postgres/src/protocol/message.rs
Normal file
93
sqlx-postgres/src/protocol/message.rs
Normal file
@ -0,0 +1,93 @@
|
||||
use std::convert::TryFrom;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use bytes::Bytes;
|
||||
use sqlx_core::io::Deserialize;
|
||||
use sqlx_core::{Error, Result};
|
||||
|
||||
/// Type of the *incoming* message.
|
||||
///
|
||||
/// Postgres does use the same message format for client and server messages but we are only
|
||||
/// interested in messages from the backend.
|
||||
///
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[repr(u8)]
|
||||
pub(crate) enum BackendMessageType {
|
||||
ParseComplete = b'1',
|
||||
BindComplete = b'2',
|
||||
CloseComplete = b'3',
|
||||
CommandComplete = b'C',
|
||||
DataRow = b'D',
|
||||
ErrorResponse = b'E',
|
||||
EmptyQueryResponse = b'I',
|
||||
NotificationResponse = b'A',
|
||||
BackendKeyData = b'K',
|
||||
NoticeResponse = b'N',
|
||||
Authentication = b'R',
|
||||
ParameterStatus = b'S',
|
||||
RowDescription = b'T',
|
||||
ReadyForQuery = b'Z',
|
||||
NoData = b'n',
|
||||
PortalSuspended = b's',
|
||||
ParameterDescription = b't',
|
||||
CopyInResponse = b'G',
|
||||
CopyOutResponse = b'H',
|
||||
CopyBothResponse = b'W',
|
||||
CopyData = b'd',
|
||||
CopyDone = b'c',
|
||||
}
|
||||
|
||||
impl TryFrom<u8> for BackendMessageType {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(ty: u8) -> Result<Self> {
|
||||
Ok(match ty {
|
||||
b'1' => Self::ParseComplete,
|
||||
b'2' => Self::BindComplete,
|
||||
b'3' => Self::CloseComplete,
|
||||
b'C' => Self::CommandComplete,
|
||||
b'D' => Self::DataRow,
|
||||
b'E' => Self::ErrorResponse,
|
||||
b'I' => Self::EmptyQueryResponse,
|
||||
b'A' => Self::NotificationResponse,
|
||||
b'K' => Self::BackendKeyData,
|
||||
b'N' => Self::NoticeResponse,
|
||||
b'R' => Self::Authentication,
|
||||
b'S' => Self::ParameterStatus,
|
||||
b'T' => Self::RowDescription,
|
||||
b'Z' => Self::ReadyForQuery,
|
||||
b'n' => Self::NoData,
|
||||
b's' => Self::PortalSuspended,
|
||||
b't' => Self::ParameterDescription,
|
||||
b'G' => Self::CopyInResponse,
|
||||
b'H' => Self::CopyOutResponse,
|
||||
b'W' => Self::CopyBothResponse,
|
||||
b'd' => Self::CopyData,
|
||||
b'c' => Self::CopyDone,
|
||||
|
||||
_ => {
|
||||
todo!("protocol unexpected data error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct BackendMessage {
|
||||
pub(crate) r#type: BackendMessageType,
|
||||
pub(crate) contents: Bytes,
|
||||
}
|
||||
|
||||
impl BackendMessage {
|
||||
#[inline]
|
||||
pub(crate) fn deserialize<'de, T>(self) -> Result<T>
|
||||
where
|
||||
T: Deserialize<'de> + Debug,
|
||||
{
|
||||
let packet = T::deserialize(self.contents)?;
|
||||
|
||||
log::trace!("read > {:?}", packet);
|
||||
|
||||
Ok(packet)
|
||||
}
|
||||
}
|
||||
@ -1,28 +0,0 @@
|
||||
use bytes::{Buf, Bytes};
|
||||
use bytestring::ByteString;
|
||||
use sqlx_core::io::BufExt;
|
||||
use sqlx_core::io::Deserialize;
|
||||
use sqlx_core::Result;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Notification {
|
||||
pub(crate) process_id: u32,
|
||||
pub(crate) channel: ByteString,
|
||||
pub(crate) payload: ByteString,
|
||||
}
|
||||
|
||||
impl Deserialize<'_, ()> for Notification {
|
||||
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
|
||||
let process_id = buf.get_u32();
|
||||
|
||||
// UNSAFE: This message will not be read.
|
||||
#[allow(unsafe_code)]
|
||||
let channel = unsafe { buf.get_str_nul_unchecked()? };
|
||||
|
||||
// UNSAFE: This message will not be read.
|
||||
#[allow(unsafe_code)]
|
||||
let payload = unsafe { buf.get_str_nul_unchecked()? };
|
||||
|
||||
Ok(Self { process_id, channel, payload })
|
||||
}
|
||||
}
|
||||
@ -1,67 +0,0 @@
|
||||
use std::fmt::Write;
|
||||
|
||||
use md5::{Digest, Md5};
|
||||
use sqlx_core::io::Serialize;
|
||||
use sqlx_core::io::WriteExt;
|
||||
use sqlx_core::Result;
|
||||
|
||||
use crate::io::PgBufMutExt;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Password<'a> {
|
||||
Cleartext(&'a str),
|
||||
|
||||
Md5 { password: &'a str, username: &'a str, salt: [u8; 4] },
|
||||
}
|
||||
|
||||
impl Password<'_> {
|
||||
#[inline]
|
||||
fn len(&self) -> usize {
|
||||
match self {
|
||||
Password::Cleartext(s) => s.len() + 5,
|
||||
Password::Md5 { .. } => 35 + 5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize<'_, ()> for Password<'_> {
|
||||
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
|
||||
buf.reserve(1 + 4 + self.len());
|
||||
buf.push(b'p');
|
||||
|
||||
buf.write_length_prefixed(|buf| {
|
||||
match self {
|
||||
Password::Cleartext(password) => {
|
||||
buf.write_str_nul(password);
|
||||
}
|
||||
|
||||
Password::Md5 { username, password, salt } => {
|
||||
// The actual `PasswordMessage` can be comwriteed in SQL as
|
||||
// `concat('md5', md5(concat(md5(concat(password, username)), random-salt)))`.
|
||||
|
||||
// Keep in mind the md5() function returns its result as a hex string.
|
||||
|
||||
let mut hasher = Md5::new();
|
||||
|
||||
hasher.update(password);
|
||||
hasher.update(username);
|
||||
|
||||
let mut output = String::with_capacity(35);
|
||||
|
||||
let _ = write!(output, "{:x}", hasher.finalize_reset());
|
||||
|
||||
hasher.update(&output);
|
||||
hasher.update(salt);
|
||||
|
||||
output.clear();
|
||||
|
||||
let _ = write!(output, "md5{:x}", hasher.finalize());
|
||||
|
||||
buf.write_str_nul(&output);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -1,51 +0,0 @@
|
||||
use std::convert::TryFrom;
|
||||
|
||||
use bytes::Bytes;
|
||||
use sqlx_core::io::Deserialize;
|
||||
use sqlx_core::Error;
|
||||
use sqlx_core::Result;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[repr(u8)]
|
||||
pub(crate) enum TransactionStatus {
|
||||
/// Not in a transaction block.
|
||||
Idle = b'I',
|
||||
|
||||
/// In a transaction block.
|
||||
Transaction = b'T',
|
||||
|
||||
/// In a _failed_ transaction block. Queries will be rejected until block is ended.
|
||||
Error = b'E',
|
||||
}
|
||||
|
||||
impl TryFrom<u8> for TransactionStatus {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(value: u8) -> Result<Self> {
|
||||
match value {
|
||||
b'I' => Ok(TransactionStatus::Idle),
|
||||
b'T' => Ok(TransactionStatus::Transaction),
|
||||
b'E' => Ok(TransactionStatus::Error),
|
||||
|
||||
status => {
|
||||
return Err(Error::configuration_msg(format!(
|
||||
"unknown transaction status: {:?}",
|
||||
status as char,
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct ReadyForQuery {
|
||||
pub transaction_status: TransactionStatus,
|
||||
}
|
||||
|
||||
impl Deserialize<'_, ()> for ReadyForQuery {
|
||||
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
|
||||
let transaction_status = TransactionStatus::try_from(buf[0])?;
|
||||
|
||||
Ok(Self { transaction_status })
|
||||
}
|
||||
}
|
||||
@ -1,190 +0,0 @@
|
||||
use std::str::from_utf8;
|
||||
|
||||
use bytes::Bytes;
|
||||
use memchr::memchr;
|
||||
use sqlx_core::io::Deserialize;
|
||||
use sqlx_core::Error;
|
||||
use sqlx_core::Result;
|
||||
|
||||
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
|
||||
#[repr(u8)]
|
||||
pub enum PgSeverity {
|
||||
Panic,
|
||||
Fatal,
|
||||
Error,
|
||||
Warning,
|
||||
Notice,
|
||||
Debug,
|
||||
Info,
|
||||
Log,
|
||||
}
|
||||
|
||||
impl PgSeverity {
|
||||
#[inline]
|
||||
pub fn is_error(self) -> bool {
|
||||
matches!(self, Self::Panic | Self::Fatal | Self::Error)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::convert::TryFrom<&str> for PgSeverity {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(s: &str) -> Result<PgSeverity> {
|
||||
let result = match s {
|
||||
"PANIC" => PgSeverity::Panic,
|
||||
"FATAL" => PgSeverity::Fatal,
|
||||
"ERROR" => PgSeverity::Error,
|
||||
"WARNING" => PgSeverity::Warning,
|
||||
"NOTICE" => PgSeverity::Notice,
|
||||
"DEBUG" => PgSeverity::Debug,
|
||||
"INFO" => PgSeverity::Info,
|
||||
"LOG" => PgSeverity::Log,
|
||||
|
||||
severity => {
|
||||
return Err(Error::configuration_msg(format!("unknown severity: {:?}", severity)));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Notice {
|
||||
storage: Bytes,
|
||||
severity: PgSeverity,
|
||||
message: (u16, u16),
|
||||
code: (u16, u16),
|
||||
}
|
||||
|
||||
impl Notice {
|
||||
#[inline]
|
||||
pub fn severity(&self) -> PgSeverity {
|
||||
self.severity
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn code(&self) -> &str {
|
||||
self.get_cached_str(self.code)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn message(&self) -> &str {
|
||||
self.get_cached_str(self.message)
|
||||
}
|
||||
|
||||
// Field descriptions available here:
|
||||
// https://www.postgresql.org/docs/current/protocol-error-fields.html
|
||||
|
||||
#[inline]
|
||||
pub fn get(&self, ty: u8) -> Option<&str> {
|
||||
self.get_raw(ty).and_then(|v| from_utf8(v).ok())
|
||||
}
|
||||
|
||||
pub fn get_raw(&self, ty: u8) -> Option<&[u8]> {
|
||||
self.fields()
|
||||
.filter(|(field, _)| *field == ty)
|
||||
.map(|(_, (start, end))| &self.storage[start as usize..end as usize])
|
||||
.next()
|
||||
}
|
||||
}
|
||||
|
||||
impl Notice {
|
||||
#[inline]
|
||||
fn fields(&self) -> Fields<'_> {
|
||||
Fields { storage: &self.storage, offset: 0 }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn get_cached_str(&self, cache: (u16, u16)) -> &str {
|
||||
// unwrap: this cannot fail at this stage
|
||||
from_utf8(&self.storage[cache.0 as usize..cache.1 as usize]).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl Deserialize<'_, ()> for Notice {
|
||||
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
|
||||
// In order to support PostgreSQL 9.5 and older we need to parse the localized S field.
|
||||
// Newer versions additionally come with the V field that is guaranteed to be in English.
|
||||
// We thus read both versions and prefer the unlocalized one if available.
|
||||
const DEFAULT_SEVERITY: PgSeverity = PgSeverity::Log;
|
||||
let mut severity_v = None;
|
||||
let mut severity_s = None;
|
||||
let mut message = (0, 0);
|
||||
let mut code = (0, 0);
|
||||
|
||||
// we cache the three always present fields
|
||||
// this enables to keep the access time down for the fields most likely accessed
|
||||
|
||||
let fields = Fields { storage: &buf, offset: 0 };
|
||||
|
||||
for (field, v) in fields {
|
||||
if message.0 != 0 && code.0 != 0 {
|
||||
// stop iterating when we have the 3 fields we were looking for
|
||||
// we assume V (severity) was the first field as it should be
|
||||
break;
|
||||
}
|
||||
|
||||
use std::convert::TryInto;
|
||||
match field {
|
||||
b'S' => {
|
||||
// Discard potential errors, because the message might be localized
|
||||
severity_s =
|
||||
from_utf8(&buf[v.0 as usize..v.1 as usize]).unwrap().try_into().ok();
|
||||
}
|
||||
|
||||
b'V' => {
|
||||
// Propagate errors here, because V is not localized and thus we are missing a possible
|
||||
// variant.
|
||||
severity_v =
|
||||
Some(from_utf8(&buf[v.0 as usize..v.1 as usize]).unwrap().try_into()?);
|
||||
}
|
||||
|
||||
b'M' => {
|
||||
message = v;
|
||||
}
|
||||
|
||||
b'C' => {
|
||||
code = v;
|
||||
}
|
||||
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
severity: severity_v.or(severity_s).unwrap_or(DEFAULT_SEVERITY),
|
||||
message,
|
||||
code,
|
||||
storage: buf,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// An iterator over each field in the Error (or Notice) response.
|
||||
struct Fields<'a> {
|
||||
storage: &'a [u8],
|
||||
offset: u16,
|
||||
}
|
||||
|
||||
impl<'a> Iterator for Fields<'a> {
|
||||
type Item = (u8, (u16, u16));
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
// The fields in the response body are sequentially stored as [tag][string],
|
||||
// ending in a final, additional [nul]
|
||||
|
||||
let ty = self.storage[self.offset as usize];
|
||||
|
||||
if ty == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let nul = memchr(b'\0', &self.storage[(self.offset + 1) as usize..])? as u16;
|
||||
let offset = self.offset;
|
||||
|
||||
self.offset += nul + 2;
|
||||
|
||||
Some((ty, (offset + 1, offset + nul + 1)))
|
||||
}
|
||||
}
|
||||
@ -1,40 +0,0 @@
|
||||
use sqlx_core::io::Serialize;
|
||||
use sqlx_core::io::WriteExt;
|
||||
use sqlx_core::Result;
|
||||
|
||||
use crate::io::PgBufMutExt;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SaslInitialResponse {
|
||||
pub response: String,
|
||||
pub plus: bool,
|
||||
}
|
||||
|
||||
impl Serialize<'_, ()> for SaslInitialResponse {
|
||||
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
|
||||
buf.push(b'p');
|
||||
buf.write_length_prefixed(|buf| {
|
||||
// name of the SASL authentication mechanism that the client selected
|
||||
buf.write_str_nul(if self.plus { "SCRAM-SHA-256-PLUS" } else { "SCRAM-SHA-256" });
|
||||
|
||||
buf.extend(&(self.response.as_bytes().len() as i32).to_be_bytes());
|
||||
buf.extend(self.response.as_bytes());
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SaslResponse<'a>(pub &'a str);
|
||||
|
||||
impl Serialize<'_, ()> for SaslResponse<'_> {
|
||||
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
|
||||
buf.push(b'p');
|
||||
buf.write_length_prefixed(|buf| {
|
||||
buf.extend(self.0.as_bytes());
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -1,64 +0,0 @@
|
||||
use sqlx_core::io::Serialize;
|
||||
use sqlx_core::io::WriteExt;
|
||||
use sqlx_core::Result;
|
||||
|
||||
use crate::io::PgBufMutExt;
|
||||
|
||||
// To begin a session, a frontend opens a connection to the server and sends a startup message.
|
||||
// This message includes the names of the user and of the database the user wants to connect to;
|
||||
// it also identifies the particular protocol version to be used.
|
||||
|
||||
// Optionally, the startup message can include additional settings for run-time parameters.
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Startup<'a> {
|
||||
/// The database user name to connect as. Required; there is no default.
|
||||
pub username: Option<&'a str>,
|
||||
|
||||
/// The database to connect to. Defaults to the user name.
|
||||
pub database: Option<&'a str>,
|
||||
|
||||
/// Additional start-up params.
|
||||
/// <https://www.postgresql.org/docs/devel/runtime-config-client.html>
|
||||
pub params: &'a [(&'a str, &'a str)],
|
||||
}
|
||||
|
||||
impl Serialize<'_, ()> for Startup<'_> {
|
||||
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
|
||||
buf.reserve(120);
|
||||
|
||||
buf.write_length_prefixed(|buf| {
|
||||
// The protocol version number. The most significant 16 bits are the
|
||||
// major version number (3 for the protocol described here). The least
|
||||
// significant 16 bits are the minor version number (0
|
||||
// for the protocol described here)
|
||||
buf.extend(&196_608_i32.to_be_bytes());
|
||||
|
||||
if let Some(username) = self.username {
|
||||
// The database user name to connect as.
|
||||
encode_startup_param(buf, "user", username);
|
||||
}
|
||||
|
||||
if let Some(database) = self.database {
|
||||
// The database to connect to. Defaults to the user name.
|
||||
encode_startup_param(buf, "database", database);
|
||||
}
|
||||
|
||||
for (name, value) in self.params {
|
||||
encode_startup_param(buf, name, value);
|
||||
}
|
||||
|
||||
// A zero byte is required as a terminator
|
||||
// after the last name/value pair.
|
||||
buf.push(0);
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn encode_startup_param(buf: &mut Vec<u8>, name: &str, value: &str) {
|
||||
buf.write_str_nul(name);
|
||||
buf.write_str_nul(value);
|
||||
}
|
||||
@ -1,14 +0,0 @@
|
||||
use sqlx_core::io::Serialize;
|
||||
use sqlx_core::Result;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Terminate;
|
||||
|
||||
impl Serialize<'_, ()> for Terminate {
|
||||
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
|
||||
buf.push(b'X');
|
||||
buf.extend(&4_u32.to_be_bytes());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
77
sqlx-postgres/src/query_result.rs
Normal file
77
sqlx-postgres/src/query_result.rs
Normal file
@ -0,0 +1,77 @@
|
||||
use std::convert::TryInto;
|
||||
use std::fmt::{self, Debug, Formatter};
|
||||
use std::str::Utf8Error;
|
||||
|
||||
use bytes::Bytes;
|
||||
use bytestring::ByteString;
|
||||
use memchr::memrchr;
|
||||
use sqlx_core::QueryResult;
|
||||
|
||||
// TODO: add unit tests for command tag parsing
|
||||
|
||||
/// Represents the execution result of a command in Postgres.
|
||||
///
|
||||
/// Returned from [`execute()`][sqlx_core::Executor::execute].
|
||||
///
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
#[derive(Clone)]
|
||||
pub struct PgQueryResult {
|
||||
command: ByteString,
|
||||
rows_affected: u64,
|
||||
}
|
||||
|
||||
impl PgQueryResult {
|
||||
pub(crate) fn parse(mut command: Bytes) -> Result<Self, Utf8Error> {
|
||||
// look backwards for the first SPACE
|
||||
let offset = memrchr(b' ', &command);
|
||||
|
||||
let rows = if let Some(offset) = offset {
|
||||
atoi::atoi(&command.split_off(offset).slice(1..)).unwrap_or(0)
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
Ok(Self { command: command.try_into()?, rows_affected: rows })
|
||||
}
|
||||
|
||||
/// Returns the command tag.
|
||||
///
|
||||
/// This is usually a single word that identifies which SQL command
|
||||
/// was completed (e.g.,`INSERT`, `UPDATE`, or `MOVE`).
|
||||
#[must_use]
|
||||
pub fn command(&self) -> &str {
|
||||
&self.command
|
||||
}
|
||||
|
||||
/// Returns the number of rows inserted, deleted, updated, retrieved,
|
||||
/// changed, or copied by the SQL command.
|
||||
#[must_use]
|
||||
pub const fn rows_affected(&self) -> u64 {
|
||||
self.rows_affected
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for PgQueryResult {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("PgQueryResult")
|
||||
.field("command", &self.command())
|
||||
.field("rows_affected", &self.rows_affected())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Extend<PgQueryResult> for PgQueryResult {
|
||||
fn extend<T: IntoIterator<Item = PgQueryResult>>(&mut self, iter: T) {
|
||||
for res in iter {
|
||||
self.rows_affected += res.rows_affected;
|
||||
self.command = res.command;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl QueryResult for PgQueryResult {
|
||||
#[inline]
|
||||
fn rows_affected(&self) -> u64 {
|
||||
self.rows_affected()
|
||||
}
|
||||
}
|
||||
55
sqlx-postgres/src/raw_value.rs
Normal file
55
sqlx-postgres/src/raw_value.rs
Normal file
@ -0,0 +1,55 @@
|
||||
use bytes::Bytes;
|
||||
use sqlx_core::RawValue;
|
||||
|
||||
use crate::{PgTypeInfo, Postgres};
|
||||
|
||||
/// The format of a raw SQL value for Postgres.
|
||||
///
|
||||
/// Postgres returns values in [`Text`] or [`Binary`] format with a
|
||||
/// configuration option in a prepared query. SQLx currently hard-codes that
|
||||
/// option to [`Binary`].
|
||||
///
|
||||
/// For simple queries, postgres only can return values in [`Text`] format.
|
||||
///
|
||||
#[derive(Debug, PartialEq, Copy, Clone)]
|
||||
pub enum PgRawValueFormat {
|
||||
Binary,
|
||||
Text,
|
||||
}
|
||||
|
||||
/// The raw representation of a SQL value for Postgres.
|
||||
// 'r: row
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
pub struct PgRawValue<'r> {
|
||||
value: Option<&'r Bytes>,
|
||||
format: PgRawValueFormat,
|
||||
type_info: PgTypeInfo,
|
||||
}
|
||||
|
||||
// 'r: row
|
||||
impl<'r> PgRawValue<'r> {
|
||||
/// Returns the type information for this value.
|
||||
#[must_use]
|
||||
pub const fn type_info(&self) -> &PgTypeInfo {
|
||||
&self.type_info
|
||||
}
|
||||
|
||||
/// Returns the format of this value.
|
||||
#[must_use]
|
||||
pub const fn format(&self) -> PgRawValueFormat {
|
||||
self.format
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r> RawValue<'r> for PgRawValue<'r> {
|
||||
type Database = Postgres;
|
||||
|
||||
fn is_null(&self) -> bool {
|
||||
self.value.is_none()
|
||||
}
|
||||
|
||||
fn type_info(&self) -> &PgTypeInfo {
|
||||
&self.type_info
|
||||
}
|
||||
}
|
||||
47
sqlx-postgres/src/row.rs
Normal file
47
sqlx-postgres/src/row.rs
Normal file
@ -0,0 +1,47 @@
|
||||
use sqlx_core::{ColumnIndex, Result, Row};
|
||||
|
||||
use crate::{PgColumn, PgRawValue, Postgres};
|
||||
|
||||
/// A single row from a result set generated from MySQL.
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
pub struct PgRow {}
|
||||
|
||||
impl Row for PgRow {
|
||||
type Database = Postgres;
|
||||
|
||||
fn is_null(&self) -> bool {
|
||||
// self.is_null()
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
// self.len()
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn columns(&self) -> &[PgColumn] {
|
||||
// self.columns()
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn try_column<I: ColumnIndex<Self>>(&self, index: I) -> Result<&PgColumn> {
|
||||
// self.try_column(index)
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn column_name(&self, index: usize) -> Option<&str> {
|
||||
// self.columns.get(index).map(PgColumn::name)
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn column_index(&self, name: &str) -> Option<usize> {
|
||||
// self.columns.iter().position(|col| col.name() == name)
|
||||
todo!()
|
||||
}
|
||||
|
||||
#[allow(clippy::needless_lifetimes)]
|
||||
fn try_get_raw<'r, I: ColumnIndex<Self>>(&'r self, index: I) -> Result<PgRawValue<'r>> {
|
||||
// self.try_get_raw(index)
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
161
sqlx-postgres/src/stream.rs
Normal file
161
sqlx-postgres/src/stream.rs
Normal file
@ -0,0 +1,161 @@
|
||||
use std::convert::TryInto;
|
||||
use std::fmt::Debug;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
use bytes::Buf;
|
||||
use sqlx_core::io::{BufStream, Serialize};
|
||||
use sqlx_core::net::Stream as NetStream;
|
||||
use sqlx_core::{Result, Runtime};
|
||||
|
||||
use crate::protocol::{BackendMessage, BackendMessageType};
|
||||
|
||||
/// Reads and writes messages to and from the PostgreSQL database server.
|
||||
///
|
||||
/// The logic for serializing data structures into the messages is found
|
||||
/// mostly in `protocol/`.
|
||||
///
|
||||
/// The first byte of a message identifies the message type, and the next
|
||||
/// four bytes specify the length of the rest of the message (this length
|
||||
/// count includes itself, but not the message-type byte). The remaining
|
||||
/// contents of the message are determined by the message type. For
|
||||
/// historical reasons, the very first message sent by the client (
|
||||
/// the startup message) has no initial message-type byte.
|
||||
///
|
||||
/// <https://dev.postgres.com/doc/internals/en/postgres-packet.html>
|
||||
///
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
pub(crate) struct PgStream<Rt: Runtime> {
|
||||
stream: BufStream<Rt, NetStream<Rt>>,
|
||||
}
|
||||
|
||||
impl<Rt: Runtime> PgStream<Rt> {
|
||||
pub(crate) fn new(stream: NetStream<Rt>) -> Self {
|
||||
Self { stream: BufStream::with_capacity(stream, 4096, 1024) }
|
||||
}
|
||||
|
||||
// all communication is through a stream of messages
|
||||
pub(crate) fn write_message<'ser, T>(&'ser mut self, message: &T) -> Result<()>
|
||||
where
|
||||
T: Serialize<'ser> + Debug,
|
||||
{
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// reads and consumes a message from the stream buffer
|
||||
// assumes there is a message on the stream
|
||||
fn read_message(&mut self, size: usize) -> Result<Option<BackendMessage>> {
|
||||
// the first byte is the message type
|
||||
let ty = self.stream.get(0, 1).get_u8();
|
||||
let ty: BackendMessageType = ty.try_into()?;
|
||||
|
||||
// the next 4 bytes was the length of the message
|
||||
self.stream.consume(5);
|
||||
|
||||
// and now take the message contents
|
||||
let contents = self.stream.take(size);
|
||||
|
||||
if contents.len() != size {
|
||||
// TODO: return a database error
|
||||
// BUG: something is very wrong somewhere if this branch is executed
|
||||
// either in the SQLx Postgres driver or in the Postgres server
|
||||
unimplemented!(
|
||||
"Received {} bytes for packet but expecting {} bytes",
|
||||
contents.len(),
|
||||
size
|
||||
);
|
||||
}
|
||||
|
||||
match ty {
|
||||
BackendMessageType::ErrorResponse => {
|
||||
// TODO: return a proper error
|
||||
unimplemented!("error response");
|
||||
}
|
||||
|
||||
BackendMessageType::NotificationResponse => {
|
||||
// TODO: handle these similar to master
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
BackendMessageType::NoticeResponse => {
|
||||
// TODO: log the incoming message
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
BackendMessageType::ParameterStatus => {
|
||||
// TODO: pull out and remember server version
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
_ => Ok(Some(BackendMessage { contents, r#type: ty })),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_read_message {
|
||||
($(@$blocking:ident)? $self:ident) => {{
|
||||
Ok(loop {
|
||||
// reads at least 5 bytes from the IO stream into the read buffer
|
||||
impl_read_message!($(@$blocking)? @stream $self, 0, 5);
|
||||
|
||||
// bytes 1..4 will be the length of the message
|
||||
let size = ($self.stream.get(1, 4).get_u32() - 4) as usize;
|
||||
|
||||
// read <size> bytes _after_ the header
|
||||
impl_read_message!($(@$blocking)? @stream $self, 4, size);
|
||||
|
||||
if let Some(message) = $self.read_message(size)? {
|
||||
break message;
|
||||
}
|
||||
})
|
||||
}};
|
||||
|
||||
(@blocking @stream $self:ident, $offset:expr, $n:expr) => {
|
||||
$self.stream.read($offset, $n)?;
|
||||
};
|
||||
|
||||
(@stream $self:ident, $offset:expr, $n:expr) => {
|
||||
$self.stream.read_async($offset, $n).await?;
|
||||
};
|
||||
}
|
||||
|
||||
impl<Rt: Runtime> PgStream<Rt> {
|
||||
#[cfg(feature = "async")]
|
||||
pub(crate) async fn read_message_async(&mut self) -> Result<BackendMessage>
|
||||
where
|
||||
Rt: sqlx_core::Async,
|
||||
{
|
||||
impl_read_message!(self)
|
||||
}
|
||||
|
||||
#[cfg(feature = "blocking")]
|
||||
pub(crate) fn read_message_blocking(&mut self) -> Result<BackendMessage>
|
||||
where
|
||||
Rt: sqlx_core::blocking::Runtime,
|
||||
{
|
||||
impl_read_message!(@blocking self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Rt: Runtime> Deref for PgStream<Rt> {
|
||||
type Target = BufStream<Rt, NetStream<Rt>>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.stream
|
||||
}
|
||||
}
|
||||
|
||||
impl<Rt: Runtime> DerefMut for PgStream<Rt> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.stream
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! read_message {
|
||||
(@blocking $stream:expr) => {
|
||||
$stream.read_message_blocking()?
|
||||
};
|
||||
|
||||
($stream:expr) => {
|
||||
$stream.read_message_async().await?
|
||||
};
|
||||
}
|
||||
120
sqlx-postgres/src/type_id.rs
Normal file
120
sqlx-postgres/src/type_id.rs
Normal file
@ -0,0 +1,120 @@
|
||||
/// A unique identifier for a Postgres data type.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
#[cfg_attr(
|
||||
any(feature = "offline", feature = "serde"),
|
||||
derive(serde::Serialize, serde::Deserialize)
|
||||
)]
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
pub enum PgTypeId {
|
||||
Oid(u32),
|
||||
Name(&'static str),
|
||||
}
|
||||
|
||||
// Data Types
|
||||
// https://www.postgresql.org/docs/current/datatype.html
|
||||
|
||||
impl PgTypeId {
|
||||
// Boolean
|
||||
// https://www.postgresql.org/docs/current/datatype-boolean.html
|
||||
|
||||
/// The SQL standard `boolean` type.
|
||||
///
|
||||
/// Maps to `bool`.
|
||||
///
|
||||
pub const BOOLEAN: Self = Self::Oid(16);
|
||||
|
||||
// Integers
|
||||
// https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-INT
|
||||
|
||||
/// A 2-byte integer.
|
||||
///
|
||||
/// Compatible with any primitive integer type.
|
||||
///
|
||||
/// Maps to `i16`.
|
||||
///
|
||||
#[doc(alias = "INT2")]
|
||||
#[doc(alias = "SMALLSERIAL")]
|
||||
pub const SMALLINT: Self = Self::Oid(21);
|
||||
|
||||
/// A 4-byte integer.
|
||||
///
|
||||
/// Compatible with any primitive integer type.
|
||||
///
|
||||
/// Maps to `i32`.
|
||||
///
|
||||
#[doc(alias = "INT4")]
|
||||
#[doc(alias = "SERIAL")]
|
||||
pub const INTEGER: Self = Self::Oid(23);
|
||||
|
||||
/// An 8-byte integer.
|
||||
///
|
||||
/// Compatible with any primitive integer type.
|
||||
///
|
||||
/// Maps to `i64`.
|
||||
///
|
||||
#[doc(alias = "INT8")]
|
||||
#[doc(alias = "BIGSERIAL")]
|
||||
pub const BIGINT: Self = Self::Oid(20);
|
||||
|
||||
// Arbitrary Precision Numbers
|
||||
// https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-NUMERIC-DECIMAL
|
||||
|
||||
/// An exact numeric type with a user-specified precision.
|
||||
///
|
||||
/// Compatible with [`bigdecimal::BigDecimal`], [`rust_decimal::Decimal`], [`num_int::BigInt`], and any
|
||||
/// primitive integer type. Truncation or loss-of-precision is considered an error
|
||||
/// when decoding into the selected Rust integer type.
|
||||
///
|
||||
/// With a scale of `0` (e.g, `NUMERIC(17, 0)`), maps to `num_int::BigInt`; otherwise,
|
||||
/// maps to [`bigdecimal::BigDecimal`] or [`rust_decimal::Decimal`] (depending on
|
||||
/// enabled crate features).
|
||||
///
|
||||
#[doc(alias = "DECIMAL")]
|
||||
pub const NUMERIC: Self = Self::Oid(1700);
|
||||
|
||||
// Floating-Point
|
||||
// https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-FLOAT
|
||||
|
||||
/// A 4-byte floating-point numeric type.
|
||||
///
|
||||
/// Compatible with `f32` or `f64`.
|
||||
///
|
||||
/// Maps to `f32`.
|
||||
///
|
||||
#[doc(alias = "FLOAT4")]
|
||||
pub const REAL: Self = Self::Oid(700);
|
||||
|
||||
/// An 8-byte floating-point numeric type.
|
||||
///
|
||||
/// Compatible with `f32` or `f64`.
|
||||
///
|
||||
/// Maps to `f64`.
|
||||
///
|
||||
#[doc(alias = "FLOAT8")]
|
||||
pub const DOUBLE: Self = Self::Oid(701);
|
||||
|
||||
/// The `UNKNOWN` Postgres type. Returned for expressions that do not
|
||||
/// have a type (e.g., `SELECT $1` with no parameter type hint
|
||||
/// or `SELECT NULL`).
|
||||
pub const UNKNOWN: Self = Self::Oid(705);
|
||||
}
|
||||
|
||||
impl PgTypeId {
|
||||
#[must_use]
|
||||
pub(crate) const fn name(self) -> &'static str {
|
||||
match self {
|
||||
Self::BOOLEAN => "BOOLEAN",
|
||||
|
||||
Self::SMALLINT => "SMALLINT",
|
||||
Self::INTEGER => "INTEGER",
|
||||
Self::BIGINT => "BIGINT",
|
||||
|
||||
Self::NUMERIC => "NUMERIC",
|
||||
|
||||
Self::REAL => "REAL",
|
||||
Self::DOUBLE => "DOUBLE",
|
||||
|
||||
_ => "UNKNOWN",
|
||||
}
|
||||
}
|
||||
}
|
||||
42
sqlx-postgres/src/type_info.rs
Normal file
42
sqlx-postgres/src/type_info.rs
Normal file
@ -0,0 +1,42 @@
|
||||
use sqlx_core::TypeInfo;
|
||||
|
||||
use crate::{PgTypeId, Postgres};
|
||||
|
||||
/// Provides information about a Postgres type.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[cfg_attr(
|
||||
any(feature = "offline", feature = "serde"),
|
||||
derive(serde::Serialize, serde::Deserialize)
|
||||
)]
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
pub struct PgTypeInfo(pub(crate) PgTypeId);
|
||||
|
||||
impl PgTypeInfo {
|
||||
/// Returns the unique identifier for this Postgres type.
|
||||
#[must_use]
|
||||
pub const fn id(&self) -> PgTypeId {
|
||||
self.0
|
||||
}
|
||||
|
||||
/// Returns the name for this Postgres type.
|
||||
#[must_use]
|
||||
pub const fn name(&self) -> &'static str {
|
||||
self.0.name()
|
||||
}
|
||||
}
|
||||
|
||||
impl TypeInfo for PgTypeInfo {
|
||||
type Database = Postgres;
|
||||
|
||||
fn id(&self) -> PgTypeId {
|
||||
self.id()
|
||||
}
|
||||
|
||||
fn is_unknown(&self) -> bool {
|
||||
matches!(self.0, PgTypeId::UNKNOWN)
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
self.name()
|
||||
}
|
||||
}
|
||||
1
sqlx-postgres/src/types.rs
Normal file
1
sqlx-postgres/src/types.rs
Normal file
@ -0,0 +1 @@
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user