From baa63d33e161dfa9f8f78e3138364bf450fbb64e Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Sat, 6 Mar 2021 13:59:24 -0800 Subject: [PATCH] refactor(postgres): baseline postgres driver against the now near-complete state of the mysql driver --- Cargo.lock | 76 +---- sqlx-core/src/raw_value.rs | 2 +- sqlx-mysql/Cargo.toml | 1 - sqlx-mysql/src/protocol/ok.rs | 2 +- sqlx-mysql/src/query_result.rs | 15 +- sqlx-mysql/src/raw_value.rs | 2 +- sqlx-mysql/src/type_info.rs | 2 +- sqlx-postgres/Cargo.toml | 18 +- sqlx-postgres/src/column.rs | 31 ++ sqlx-postgres/src/connection.rs | 124 -------- sqlx-postgres/src/connection/close.rs | 32 -- sqlx-postgres/src/connection/connect.rs | 180 ----------- sqlx-postgres/src/connection/ping.rs | 22 -- sqlx-postgres/src/connection/sasl.rs | 294 ------------------ sqlx-postgres/src/connection/stream.rs | 162 ---------- sqlx-postgres/src/database.rs | 29 +- sqlx-postgres/src/error.rs | 41 --- sqlx-postgres/src/io.rs | 3 - sqlx-postgres/src/io/write.rs | 52 ---- sqlx-postgres/src/lib.rs | 47 ++- sqlx-postgres/src/options.rs | 103 ------ sqlx-postgres/src/options/builder.rs | 82 ----- sqlx-postgres/src/options/default.rs | 38 --- sqlx-postgres/src/options/getters.rs | 55 ---- sqlx-postgres/src/options/parse.rs | 183 ----------- sqlx-postgres/src/output.rs | 15 + sqlx-postgres/src/protocol.rs | 112 +------ sqlx-postgres/src/protocol/authentication.rs | 204 ------------ .../src/protocol/backend_key_data.rs | 25 -- sqlx-postgres/src/protocol/close.rs | 36 --- sqlx-postgres/src/protocol/flush.rs | 18 -- sqlx-postgres/src/protocol/message.rs | 93 ++++++ sqlx-postgres/src/protocol/notification.rs | 28 -- sqlx-postgres/src/protocol/password.rs | 67 ---- sqlx-postgres/src/protocol/ready_for_query.rs | 51 --- sqlx-postgres/src/protocol/response.rs | 190 ----------- sqlx-postgres/src/protocol/sasl.rs | 40 --- sqlx-postgres/src/protocol/startup.rs | 64 ---- sqlx-postgres/src/protocol/terminate.rs | 14 - sqlx-postgres/src/query_result.rs | 77 +++++ sqlx-postgres/src/raw_value.rs | 55 ++++ sqlx-postgres/src/row.rs | 47 +++ sqlx-postgres/src/stream.rs | 161 ++++++++++ sqlx-postgres/src/type_id.rs | 120 +++++++ sqlx-postgres/src/type_info.rs | 42 +++ sqlx-postgres/src/types.rs | 1 + 46 files changed, 721 insertions(+), 2335 deletions(-) create mode 100644 sqlx-postgres/src/column.rs delete mode 100644 sqlx-postgres/src/connection.rs delete mode 100644 sqlx-postgres/src/connection/close.rs delete mode 100644 sqlx-postgres/src/connection/connect.rs delete mode 100644 sqlx-postgres/src/connection/ping.rs delete mode 100644 sqlx-postgres/src/connection/sasl.rs delete mode 100644 sqlx-postgres/src/connection/stream.rs delete mode 100644 sqlx-postgres/src/error.rs delete mode 100644 sqlx-postgres/src/io.rs delete mode 100644 sqlx-postgres/src/io/write.rs delete mode 100644 sqlx-postgres/src/options.rs delete mode 100644 sqlx-postgres/src/options/builder.rs delete mode 100644 sqlx-postgres/src/options/default.rs delete mode 100644 sqlx-postgres/src/options/getters.rs delete mode 100644 sqlx-postgres/src/options/parse.rs create mode 100644 sqlx-postgres/src/output.rs delete mode 100644 sqlx-postgres/src/protocol/authentication.rs delete mode 100644 sqlx-postgres/src/protocol/backend_key_data.rs delete mode 100644 sqlx-postgres/src/protocol/close.rs delete mode 100644 sqlx-postgres/src/protocol/flush.rs create mode 100644 sqlx-postgres/src/protocol/message.rs delete mode 100644 sqlx-postgres/src/protocol/notification.rs delete mode 100644 sqlx-postgres/src/protocol/password.rs delete mode 100644 sqlx-postgres/src/protocol/ready_for_query.rs delete mode 100644 sqlx-postgres/src/protocol/response.rs delete mode 100644 sqlx-postgres/src/protocol/sasl.rs delete mode 100644 sqlx-postgres/src/protocol/startup.rs delete mode 100644 sqlx-postgres/src/protocol/terminate.rs create mode 100644 sqlx-postgres/src/query_result.rs create mode 100644 sqlx-postgres/src/raw_value.rs create mode 100644 sqlx-postgres/src/row.rs create mode 100644 sqlx-postgres/src/stream.rs create mode 100644 sqlx-postgres/src/type_id.rs create mode 100644 sqlx-postgres/src/type_info.rs create mode 100644 sqlx-postgres/src/types.rs diff --git a/Cargo.lock b/Cargo.lock index 2893eb93..f541f035 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -379,16 +379,6 @@ dependencies = [ "lazy_static", ] -[[package]] -name = "crypto-mac" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4857fd85a0c34b3c3297875b747c1e02e06b6a0ea32dd892d8192b9ce0813ea6" -dependencies = [ - "generic-array", - "subtle", -] - [[package]] name = "ctor" version = "0.1.19" @@ -565,16 +555,6 @@ dependencies = [ "libc", ] -[[package]] -name = "hmac" -version = "0.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1441c6b1e930e2817404b5046f1f989899143a12bf92de603b69f4e0aee1e15" -dependencies = [ - "crypto-mac", - "digest", -] - [[package]] name = "idna" version = "0.2.2" @@ -595,12 +575,6 @@ dependencies = [ "cfg-if 1.0.0", ] -[[package]] -name = "itoa" -version = "0.4.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736" - [[package]] name = "js-sys" version = "0.3.48" @@ -630,9 +604,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.87" +version = "0.2.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "265d751d31d6780a3f956bb5b8022feba2d94eeee5a84ba64f4212eedca42213" +checksum = "03b07a082330a35e43f63177cc01689da34fbffa0105e1246cf0311472cac73a" [[package]] name = "libm" @@ -665,17 +639,6 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08" -[[package]] -name = "md-5" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b5a279bb9607f9f53c22d496eade00d138d1bdcccd07d74650387cf94942a15" -dependencies = [ - "block-buffer", - "digest", - "opaque-debug", -] - [[package]] name = "memchr" version = "2.3.4" @@ -865,9 +828,9 @@ checksum = "d4fd5641d01c8f18a23da7b6fe29298ff4b55afcccdf78973b24cf3175fee32e" [[package]] name = "pin-project-lite" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cf491442e4b033ed1c722cb9f0df5fcfcf4de682466c46469c36bc47dc5548a" +checksum = "dc0e1f259c92177c30a4c9d177246edd0a3568b25756a977d0632cf8fa37e905" [[package]] name = "pin-utils" @@ -1019,9 +982,9 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "serde" -version = "1.0.123" +version = "1.0.124" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92d5161132722baa40d802cc70b15262b98258453e85e5d1d365c757c73869ae" +checksum = "bd761ff957cb2a45fbb9ab3da6512de9de55872866160b23c25f1a841e99d29f" [[package]] name = "sha-1" @@ -1159,41 +1122,20 @@ version = "0.6.0-pre" dependencies = [ "anyhow", "atoi", - "base64", "bitflags", - "byteorder", "bytes", "bytestring", - "crypto-mac", - "either", + "conquer-once", "futures-executor", "futures-io", "futures-util", - "hmac", - "itoa", "log", - "md-5", "memchr", "percent-encoding", - "rand", - "rsa", - "sha-1", - "sha2", "sqlx-core", - "stringprep", "url", ] -[[package]] -name = "stringprep" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ee348cb74b87454fff4b551cbf727025810a004f88aeacae7f85b87f4e9a1c1" -dependencies = [ - "unicode-bidi", - "unicode-normalization", -] - [[package]] name = "subtle" version = "2.4.0" @@ -1202,9 +1144,9 @@ checksum = "1e81da0851ada1f3e9d4312c704aa4f8806f0f9d69faaf8df2f3464b4a9437c2" [[package]] name = "syn" -version = "1.0.60" +version = "1.0.62" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c700597eca8a5a762beb35753ef6b94df201c81cca676604f547495a0d7f0081" +checksum = "123a78a3596b24fee53a6464ce52d8ecbf62241e6294c7e7fe12086cd161f512" dependencies = [ "proc-macro2", "quote", diff --git a/sqlx-core/src/raw_value.rs b/sqlx-core/src/raw_value.rs index 4593323e..a5415bb8 100644 --- a/sqlx-core/src/raw_value.rs +++ b/sqlx-core/src/raw_value.rs @@ -9,5 +9,5 @@ pub trait RawValue<'r>: Sized { fn is_null(&self) -> bool; /// Returns the type information for this value. - fn type_info(&self) -> &'r ::TypeInfo; + fn type_info(&self) -> &::TypeInfo; } diff --git a/sqlx-mysql/Cargo.toml b/sqlx-mysql/Cargo.toml index 87837209..14833e8a 100644 --- a/sqlx-mysql/Cargo.toml +++ b/sqlx-mysql/Cargo.toml @@ -23,7 +23,6 @@ default = [] blocking = ["sqlx-core/blocking"] # async runtime -# not meant to be used directly async = ["futures-util", "sqlx-core/async", "futures-io"] [dependencies] diff --git a/sqlx-mysql/src/protocol/ok.rs b/sqlx-mysql/src/protocol/ok.rs index eb3cc71d..3ac10141 100644 --- a/sqlx-mysql/src/protocol/ok.rs +++ b/sqlx-mysql/src/protocol/ok.rs @@ -12,7 +12,7 @@ use crate::protocol::{Capabilities, Info, Status}; /// An OK packet is sent from the server to the client to signal successful completion of a command. /// As of MySQL 5.7.5, OK packes are also used to indicate EOF, and EOF packets are deprecated. #[allow(clippy::module_name_repetitions)] -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) struct OkPacket { pub(crate) affected_rows: u64, pub(crate) last_insert_id: u64, diff --git a/sqlx-mysql/src/query_result.rs b/sqlx-mysql/src/query_result.rs index e21156a1..2df67d33 100644 --- a/sqlx-mysql/src/query_result.rs +++ b/sqlx-mysql/src/query_result.rs @@ -4,11 +4,12 @@ use sqlx_core::QueryResult; use crate::protocol::{Info, OkPacket, Status}; -/// Represents the execution result of an operation on the database server. +/// Represents the execution result of an operation in MySQL. /// /// Returned from [`execute()`][sqlx_core::Executor::execute]. /// #[allow(clippy::module_name_repetitions)] +#[derive(Clone)] pub struct MySqlQueryResult(pub(crate) OkPacket); impl MySqlQueryResult { @@ -109,18 +110,6 @@ impl Debug for MySqlQueryResult { } } -impl Default for MySqlQueryResult { - fn default() -> Self { - Self(OkPacket { - affected_rows: 0, - last_insert_id: 0, - status: Status::empty(), - warnings: 0, - info: Info::default(), - }) - } -} - impl From for MySqlQueryResult { fn from(ok: OkPacket) -> Self { Self(ok) diff --git a/sqlx-mysql/src/raw_value.rs b/sqlx-mysql/src/raw_value.rs index d0f85bea..3c2265e0 100644 --- a/sqlx-mysql/src/raw_value.rs +++ b/sqlx-mysql/src/raw_value.rs @@ -87,7 +87,7 @@ impl<'r> RawValue<'r> for MySqlRawValue<'r> { self.value.is_none() } - fn type_info(&self) -> &'r MySqlTypeInfo { + fn type_info(&self) -> &MySqlTypeInfo { self.type_info } } diff --git a/sqlx-mysql/src/type_info.rs b/sqlx-mysql/src/type_info.rs index 78f24334..5aff13bc 100644 --- a/sqlx-mysql/src/type_info.rs +++ b/sqlx-mysql/src/type_info.rs @@ -101,7 +101,7 @@ impl MySqlTypeInfo { MySqlTypeId::CHAR => "CHAR", MySqlTypeId::TEXT => "TEXT", - _ => "", + _ => "UNKNOWN", } } } diff --git a/sqlx-postgres/Cargo.toml b/sqlx-postgres/Cargo.toml index 7f1eb05f..76e5e2ae 100644 --- a/sqlx-postgres/Cargo.toml +++ b/sqlx-postgres/Cargo.toml @@ -2,7 +2,7 @@ name = "sqlx-postgres" version = "0.6.0-pre" repository = "https://github.com/launchbadge/sqlx" -description = "MySQL database driver for SQLx, the Rust SQL Toolkit." +description = "PostgreSQL database driver for SQLx, the Rust SQL Toolkit." license = "MIT OR Apache-2.0" edition = "2018" keywords = ["postgres", "sqlx", "database"] @@ -23,13 +23,12 @@ default = [] blocking = ["sqlx-core/blocking"] # async runtime -# not meant to be used directly async = ["futures-util", "sqlx-core/async", "futures-io"] [dependencies] +atoi = "0.4.0" sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core" } futures-util = { version = "0.3.8", optional = true } -either = "1.6.1" log = "0.4.11" bytestring = "1.0.0" url = "2.2.0" @@ -38,20 +37,9 @@ futures-io = { version = "0.3", optional = true } bytes = "1.0" memchr = "2.3" bitflags = "1.2" -sha-1 = "0.9.2" -sha2 = "0.9.2" -rsa = "0.3.0" -base64 = "0.13.0" -rand = "0.7" -itoa = "0.4.7" -atoi = "0.4.0" -byteorder = { version = "1.4.2", default-features = false, features = [ "std" ] } -md-5 = { version = "0.9.1", default-features = false } -hmac = { version = "0.10.1", default-features = false } -stringprep = "0.1.2" -crypto-mac = "0.10.0" [dev-dependencies] sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core", features = ["_mock"] } futures-executor = "0.3.8" anyhow = "1.0.37" +conquer-once = "0.3.2" diff --git a/sqlx-postgres/src/column.rs b/sqlx-postgres/src/column.rs new file mode 100644 index 00000000..1edf267c --- /dev/null +++ b/sqlx-postgres/src/column.rs @@ -0,0 +1,31 @@ +use bytestring::ByteString; +use sqlx_core::{Column, Database}; + +use crate::{PgTypeInfo, Postgres}; + +// TODO: inherent methods from + +/// Represents a column from a query in Postgres. +#[allow(clippy::module_name_repetitions)] +#[derive(Debug, Clone)] +pub struct PgColumn { + index: usize, + name: ByteString, + type_info: PgTypeInfo, +} + +impl Column for PgColumn { + type Database = Postgres; + + fn name(&self) -> &str { + &self.name + } + + fn index(&self) -> usize { + self.index + } + + fn type_info(&self) -> &PgTypeInfo { + &self.type_info + } +} diff --git a/sqlx-postgres/src/connection.rs b/sqlx-postgres/src/connection.rs deleted file mode 100644 index 3527e340..00000000 --- a/sqlx-postgres/src/connection.rs +++ /dev/null @@ -1,124 +0,0 @@ -use std::fmt::{self, Debug, Formatter}; - -use sqlx_core::io::BufStream; -use sqlx_core::net::Stream as NetStream; -use sqlx_core::{Close, Connect, Connection, Runtime}; - -use crate::{Postgres, PostgresConnectOptions}; - -#[macro_use] -mod sasl; - -mod close; -mod connect; -mod ping; -mod stream; - -/// A single connection (also known as a session) to a PostgreSQL database server. -#[allow(clippy::module_name_repetitions)] -pub struct PostgresConnection -where - Rt: Runtime, -{ - stream: BufStream>, - - // process id of this backend - // used to send cancel requests - #[allow(dead_code)] - process_id: u32, - - // secret key of this backend - // used to send cancel requests - #[allow(dead_code)] - secret_key: u32, -} - -impl PostgresConnection -where - Rt: Runtime, -{ - pub(crate) fn new(stream: NetStream) -> Self { - Self { stream: BufStream::with_capacity(stream, 4096, 1024), process_id: 0, secret_key: 0 } - } -} - -impl Debug for PostgresConnection -where - Rt: Runtime, -{ - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.debug_struct("PostgresConnection").finish() - } -} - -impl Connection for PostgresConnection -where - Rt: Runtime, -{ - type Database = Postgres; - - #[cfg(feature = "async")] - fn ping(&mut self) -> futures_util::future::BoxFuture<'_, sqlx_core::Result<()>> - where - Rt: sqlx_core::Async, - { - Box::pin(self.ping_async()) - } -} - -impl Connect for PostgresConnection { - type Options = PostgresConnectOptions; - - #[cfg(feature = "async")] - fn connect(url: &str) -> futures_util::future::BoxFuture<'_, sqlx_core::Result> - where - Self: Sized, - Rt: sqlx_core::Async, - { - use sqlx_core::ConnectOptions; - - let options = url.parse::(); - Box::pin(async move { options?.connect().await }) - } -} - -impl Close for PostgresConnection { - #[cfg(feature = "async")] - fn close(self) -> futures_util::future::BoxFuture<'static, sqlx_core::Result<()>> - where - Rt: sqlx_core::Async, - { - Box::pin(self.close_async()) - } -} - -#[cfg(feature = "blocking")] -mod blocking { - use sqlx_core::blocking::{Close, Connect, Connection, Runtime}; - - use super::{PostgresConnectOptions, PostgresConnection}; - - impl Connection for PostgresConnection { - #[inline] - fn ping(&mut self) -> sqlx_core::Result<()> { - self.ping() - } - } - - impl Connect for PostgresConnection { - #[inline] - fn connect(url: &str) -> sqlx_core::Result - where - Self: Sized, - { - Self::connect(&url.parse::>()?) - } - } - - impl Close for PostgresConnection { - #[inline] - fn close(self) -> sqlx_core::Result<()> { - self.close() - } - } -} diff --git a/sqlx-postgres/src/connection/close.rs b/sqlx-postgres/src/connection/close.rs deleted file mode 100644 index 7bbefeb2..00000000 --- a/sqlx-postgres/src/connection/close.rs +++ /dev/null @@ -1,32 +0,0 @@ -use sqlx_core::{io::Stream, Result, Runtime}; - -use crate::protocol::Terminate; - -impl super::PostgresConnection -where - Rt: Runtime, -{ - #[cfg(feature = "async")] - pub(crate) async fn close_async(mut self) -> Result<()> - where - Rt: sqlx_core::Async, - { - self.write_packet(&Terminate)?; - self.stream.flush_async().await?; - self.stream.shutdown_async().await?; - - Ok(()) - } - - #[cfg(feature = "blocking")] - pub(crate) fn close(mut self) -> Result<()> - where - Rt: sqlx_core::blocking::Runtime, - { - self.write_packet(&Terminate)?; - self.stream.flush()?; - self.stream.shutdown()?; - - Ok(()) - } -} diff --git a/sqlx-postgres/src/connection/connect.rs b/sqlx-postgres/src/connection/connect.rs deleted file mode 100644 index 294b751b..00000000 --- a/sqlx-postgres/src/connection/connect.rs +++ /dev/null @@ -1,180 +0,0 @@ -//! Implements the connection phase. -//! -//! The connection phase (establish) performs these tasks: -//! -//! - exchange the capabilities of client and server -//! - setup SSL communication channel if requested -//! - authenticate the client against the server -//! -//! The server may immediately send an ERR packet and finish the handshake -//! or send a `Handshake`. -//! -//! https://dev.postgres.com/doc/internals/en/connection-phase.html -//! -use hmac::{Hmac, Mac, NewMac}; -use sha2::{Digest, Sha256}; -use sqlx_core::net::Stream as NetStream; -use sqlx_core::Error; -use sqlx_core::Result; - -use crate::protocol::{ - Authentication, BackendKeyData, Message, MessageType, Password, ReadyForQuery, - SaslInitialResponse, SaslResponse, Startup, -}; -use crate::{PostgresConnectOptions, PostgresConnection, PostgresDatabaseError}; - -macro_rules! connect { - (@blocking @tcp $options:ident) => { - NetStream::connect($options.address.as_ref())?; - }; - - (@tcp $options:ident) => { - NetStream::connect_async($options.address.as_ref()).await?; - }; - - (@blocking @packet $self:ident) => { - $self.read_packet()?; - }; - - (@packet $self:ident) => { - $self.read_packet_async().await?; - }; - - ($(@$blocking:ident)? $options:ident) => {{ - // open a network stream to the database server - let stream = connect!($(@$blocking)? @tcp $options); - - // construct a around the network stream - // wraps the stream in a to buffer read and write - let mut self_ = Self::new(stream); - - // To begin a session, a frontend opens a connection to the server - // and sends a startup message. - - let mut params = vec![ // Sets the display format for date and time values, as well as the rules for interpreting ambiguous date input values. ("DateStyle", "ISO, MDY"), - // Sets the client-side encoding (character set). - // - ("client_encoding", "UTF8"), - // Sets the time zone for displaying and interpreting time stamps. - ("TimeZone", "UTC"), - // Adjust postgres to return precise values for floats - // NOTE: This is default in postgres 12+ - ("extra_float_digits", "3"), - ]; - - // if let Some(ref application_name) = $options.get_application_name() { - // params.push(("application_name", application_name)); - // } - - self_.write_packet(&Startup { - username: $options.get_username(), - database: $options.get_database(), - params: ¶ms, - })?; - - // The server then uses this information and the contents of - // its configuration files (such as pg_hba.conf) to determine whether the connection is - // provisionally acceptable, and what additional - // authentication is required (if any). - - let mut process_id = 0; - let mut secret_key = 0; - let transaction_status; - - loop { - let message: Message = connect!($(@$blocking)? @packet self_); - match message.r#type { - MessageType::Authentication => match message.decode()? { - Authentication::Ok => { - // the authentication exchange is successfully completed - // do nothing; no more information is required to continue - } - - Authentication::CleartextPassword => { - // The frontend must now send a [PasswordMessage] containing the - // password in clear-text form. - - self_ - .write_packet(&Password::Cleartext( - $options.get_password().unwrap_or_default(), - ))?; - } - - Authentication::Md5Password(body) => { - // The frontend must now send a [PasswordMessage] containing the - // password (with user name) encrypted via MD5, then encrypted again - // using the 4-byte random salt specified in the - // [AuthenticationMD5Password] message. - - self_ - .write_packet(&Password::Md5 { - username: $options.get_username().unwrap_or_default(), - password: $options.get_password().unwrap_or_default(), - salt: body.salt, - })?; - } - - Authentication::Sasl(body) => { - sasl_authenticate!($(@$blocking)? self_, $options, body) - } - - method => { - return Err(Error::configuration_msg(format!( - "unsupported authentication method: {:?}", - method - ))); - } - }, - - MessageType::BackendKeyData => { - // provides secret-key data that the frontend must save if it wants to be - // able to issue cancel requests later - - let data: BackendKeyData = message.decode()?; - - process_id = data.process_id; - secret_key = data.secret_key; - } - - MessageType::ReadyForQuery => { - let ready: ReadyForQuery = message.decode()?; - - // start-up is completed. The frontend can now issue commands - transaction_status = ready.transaction_status; - - break; - } - - _ => { - return Err(Error::configuration_msg(format!( - "establish: unexpected message: {:?}", - message.r#type - ))) - } - } - } - - Ok(self_) - }}; -} - -impl PostgresConnection -where - Rt: sqlx_core::Runtime, -{ - #[cfg(feature = "async")] - pub(crate) async fn connect_async(options: &PostgresConnectOptions) -> Result - where - Rt: sqlx_core::Async, - { - connect!(options) - } - - #[cfg(feature = "blocking")] - pub(crate) fn connect(options: &PostgresConnectOptions) -> Result - where - Rt: sqlx_core::blocking::Runtime, - { - connect!(@blocking options) - } -} diff --git a/sqlx-postgres/src/connection/ping.rs b/sqlx-postgres/src/connection/ping.rs deleted file mode 100644 index 9e3c5ee0..00000000 --- a/sqlx-postgres/src/connection/ping.rs +++ /dev/null @@ -1,22 +0,0 @@ -use sqlx_core::{Result, Runtime}; - -impl super::PostgresConnection -where - Rt: Runtime, -{ - #[cfg(feature = "async")] - pub(crate) async fn ping_async(&mut self) -> Result<()> - where - Rt: sqlx_core::Async, - { - todo!(); - } - - #[cfg(feature = "blocking")] - pub(crate) fn ping(&mut self) -> Result<()> - where - Rt: sqlx_core::blocking::Runtime, - { - todo!(); - } -} diff --git a/sqlx-postgres/src/connection/sasl.rs b/sqlx-postgres/src/connection/sasl.rs deleted file mode 100644 index d97e9c73..00000000 --- a/sqlx-postgres/src/connection/sasl.rs +++ /dev/null @@ -1,294 +0,0 @@ -use hmac::{Hmac, Mac, NewMac}; -use rand::Rng; -use sha2::digest::Digest; -use sha2::Sha256; -use sqlx_core::Error; -use sqlx_core::Result; -use sqlx_core::Runtime; - -use crate::protocol::{ - Authentication, AuthenticationSasl, Message, MessageType, SaslInitialResponse, SaslResponse, -}; -use crate::PostgresConnectOptions; -use crate::PostgresConnection; -use crate::PostgresDatabaseError; - -pub(super) const GS2_HEADER: &str = "n,,"; -pub(super) const CHANNEL_ATTR: &str = "c"; -pub(super) const USERNAME_ATTR: &str = "n"; -pub(super) const CLIENT_PROOF_ATTR: &str = "p"; -pub(super) const NONCE_ATTR: &str = "r"; - -pub(super) fn sasl_init_response( - self_: &mut PostgresConnection, - options: &PostgresConnectOptions, - data: AuthenticationSasl, -) -> Result<(String, String, String)> { - let mut has_sasl = false; - let mut has_sasl_plus = false; - let mut unknown = Vec::new(); - - for mechanism in data.mechanisms() { - match mechanism { - "SCRAM-SHA-256" => { - has_sasl = true; - } - - "SCRAM-SHA-256-PLUS" => { - has_sasl_plus = true; - } - - _ => { - unknown.push(mechanism.to_owned()); - } - } - } - - if !has_sasl_plus && !has_sasl { - return Err(Error::connect(PostgresDatabaseError::protocol(format!( - "unsupported SASL authentication mechanisms: {}", - unknown.join(", ") - )))); - } - - // channel-binding = "c=" base64 - let channel_binding = format!( - "{}={}", - crate::connection::sasl::CHANNEL_ATTR, - base64::encode(crate::connection::sasl::GS2_HEADER) - ); - - // "n=" saslname ;; Usernames are prepared using SASLprep. - let username = format!("{}={}", USERNAME_ATTR, options.get_username().unwrap_or_default()); - let username = match stringprep::saslprep(&username) { - Ok(v) => v, - Err(err) => { - return Err(Error::connect(PostgresDatabaseError::protocol(format!( - "failed to sasl prep the username: {:?}", - err - )))); - } - }; - - // nonce = "r=" c-nonce [s-nonce] ;; Second part provided by server. - let nonce = gen_nonce(); - - // client-first-message-bare = [reserved-mext ","] username "," nonce ["," extensions] - let client_first_message_bare = - format!("{username},{nonce}", username = username, nonce = nonce); - - let client_first_message = format!( - "{gs2_header}{client_first_message_bare}", - gs2_header = GS2_HEADER, - client_first_message_bare = client_first_message_bare - ); - - self_.write_packet(&SaslInitialResponse { - response: client_first_message.clone(), - plus: false, - })?; - - Ok((channel_binding, client_first_message_bare, client_first_message)) -} - -pub(super) fn sasl_response<'a, Rt: Runtime>( - self_: &mut PostgresConnection, - message: Message, - options: &PostgresConnectOptions, - channel_binding: &String, - client_first_message_bare: &'a String, -) -> Result> { - let cont = match message.r#type { - MessageType::Authentication => match message.decode()? { - Authentication::SaslContinue(data) => data, - - auth => { - return Err(Error::connect(PostgresDatabaseError::protocol(format!( - "expected SASLContinue but received {:?}", - auth - )))); - } - }, - - r#type => { - return Err(Error::connect(PostgresDatabaseError::protocol(format!( - "Expected an authencation message type, found {:?}", - r#type - )))); - } - }; - - // SaltedPassword := Hi(Normalize(password), salt, i) - let salted_password = - hi(options.get_password().unwrap_or_default(), &cont.salt, cont.iterations)?; - - // ClientKey := HMAC(SaltedPassword, "Client Key") - let mut mac = Hmac::::new_varkey(&salted_password) - .map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?; - mac.update(b"Client Key"); - - let client_key = mac.finalize().into_bytes(); - - // StoredKey := H(ClientKey) - let stored_key = Sha256::digest(&client_key); - - // client-final-message-without-proof - let client_final_message_wo_proof = format!( - "{channel_binding},r={nonce}", - channel_binding = channel_binding, - nonce = &cont.nonce - ); - - // AuthMessage := client-first-message-bare + "," + server-first-message + "," + client-final-message-without-proof - let auth_message = format!( - "{client_first_message_bare},{server_first_message},{client_final_message_wo_proof}", - client_first_message_bare = client_first_message_bare, - server_first_message = cont.message, - client_final_message_wo_proof = client_final_message_wo_proof - ); - - // ClientSignature := HMAC(StoredKey, AuthMessage) - let mut mac = Hmac::::new_varkey(&stored_key) - .map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?; - mac.update(&auth_message.as_bytes()); - - let client_signature = mac.finalize().into_bytes(); - - // ClientProof := ClientKey XOR ClientSignature - let client_proof: Vec = - client_key.iter().zip(client_signature.iter()).map(|(&a, &b)| a ^ b).collect(); - - // ServerKey := HMAC(SaltedPassword, "Server Key") - let mut mac = Hmac::::new_varkey(&salted_password) - .map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?; - mac.update(b"Server Key"); - - let server_key = mac.finalize().into_bytes(); - - // ServerSignature := HMAC(ServerKey, AuthMessage) - let mut mac = Hmac::::new_varkey(&server_key) - .map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?; - mac.update(&auth_message.as_bytes()); - - // client-final-message = client-final-message-without-proof "," proof - let client_final_message = format!( - "{client_final_message_wo_proof},{client_proof_attr}={client_proof}", - client_final_message_wo_proof = client_final_message_wo_proof, - client_proof_attr = crate::connection::sasl::CLIENT_PROOF_ATTR, - client_proof = base64::encode(&client_proof) - ); - - self_.write_packet(&SaslResponse(client_first_message_bare))?; - - Ok(mac) -} - -pub(super) fn sasl_final(message: Message, mac: Hmac) -> Result<()> { - let data = match message.r#type { - MessageType::Authentication => match message.decode()? { - Authentication::SaslFinal(data) => data, - - auth => { - return Err(Error::connect(PostgresDatabaseError::protocol(format!( - "expected SASLContinue but received {:?}", - auth - )))); - } - }, - - r#type => { - return Err(Error::connect(PostgresDatabaseError::protocol(format!( - "Expected an authencation message type, found {:?}", - r#type - )))); - } - }; - - // authentication is only considered valid if this verification passes - mac.verify(&data.verifier) - .map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?; - - Ok(()) -} - -macro_rules! sasl_authenticate { - (@blocking @packet $self:ident) => { - $self.read_packet()?; - }; - - (@packet $self:ident) => { - $self.read_packet_async().await?; - }; - - ($(@$blocking:ident)? $self:ident, $options:ident, $data:ident) => {{ - let ( - channel_binding, - client_first_message_bare, - client_first_message, - ) = crate::connection::sasl::sasl_init_response(&mut $self, $options, $data)?; - - let message: Message = sasl_authenticate!($(@$blocking)? @packet $self); - let mac = crate::connection::sasl::sasl_response( - &mut $self, - message, - $options, - &channel_binding, - &client_first_message_bare, - )?; - - let message: Message = sasl_authenticate!($(@$blocking)? @packet $self); - crate::connection::sasl::sasl_final( - message, - mac - )?; - }}; -} - -// nonce is a sequence of random printable bytes -fn gen_nonce() -> String { - let mut rng = rand::thread_rng(); - let count = rng.gen_range(64, 128); - - // printable = %x21-2B / %x2D-7E - // ;; Printable ASCII except ",". - // ;; Note that any "printable" is also - // ;; a valid "value". - let nonce: String = std::iter::repeat(()) - .map(|()| { - let mut c = rng.gen_range(0x21, 0x7F) as u8; - - while c == 0x2C { - c = rng.gen_range(0x21, 0x7F) as u8; - } - - c - }) - .take(count) - .map(|c| c as char) - .collect(); - - rng.gen_range(32, 128); - format!("{}={}", crate::connection::sasl::NONCE_ATTR, nonce) -} - -// Hi(str, salt, i): -fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32]> { - let mut mac = Hmac::::new_varkey(s.as_bytes()) - .map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?; - - mac.update(&salt); - mac.update(&1u32.to_be_bytes()); - - let mut u = mac.finalize().into_bytes(); - let mut hi = u; - - for _ in 1..iter_count { - let mut mac = Hmac::::new_varkey(s.as_bytes()) - .map_err(|err| Error::connect(crate::PostgresDatabaseError::from(err)))?; - mac.update(u.as_slice()); - u = mac.finalize().into_bytes(); - hi = hi.iter().zip(u.iter()).map(|(&a, &b)| a ^ b).collect(); - } - - Ok(hi.into()) -} diff --git a/sqlx-postgres/src/connection/stream.rs b/sqlx-postgres/src/connection/stream.rs deleted file mode 100644 index eeec4a9b..00000000 --- a/sqlx-postgres/src/connection/stream.rs +++ /dev/null @@ -1,162 +0,0 @@ -//! Reads and writes packets to and from the PostgreSQL database server. -//! -//! The logic for serializing data structures into the packets is found -//! mostly in `protocol/`. -//! -//! Packets in PostgreSQL are prefixed by 4 bytes. -//! 3 for length (in LE) and a sequence id. -//! -//! Packets may only be as large as the communicated size in the initial -//! `HandshakeResponse`. By default, SQLx configures its chunk size to 16M. Sending -//! a larger payload is simply sending completely "full" packets, one after the -//! other, with an increasing sequence id. -//! -//! In other words, when we sent data, we: -//! -//! - Split the data into "packets" of size `2 ** 24 - 1` bytes. -//! -//! - Prepend each packet with a **packet header**, consisting of the length of that packet, -//! and the sequence number. -//! -//! https://dev.postgres.com/doc/internals/en/postgres-packet.html -//! -use std::convert::TryFrom; -use std::fmt::Debug; - -use bytes::{Buf, Bytes}; -use log::Level; -use sqlx_core::io::{Deserialize, Serialize}; -use sqlx_core::{Result, Runtime}; - -use crate::protocol::{Message, MessageType, Notice, PgSeverity}; -use crate::PostgresConnection; - -impl PostgresConnection -where - Rt: Runtime, -{ - pub(super) fn write_packet<'ser, T>(&'ser mut self, packet: &T) -> Result<()> - where - T: Serialize<'ser, ()> + Debug, - { - log::trace!("write > {:?}", packet); - - let buf = self.stream.buffer(); - packet.serialize_with(buf, ())?; - - Ok(()) - } -} - -macro_rules! read_packet { - ($(@$blocking:ident)? $self:ident) => {{ - loop { - read_packet!($(@$blocking)? @stream $self, 0, 5); - - // peek at the messaage type and payload size - let r#type = MessageType::try_from(*$self.stream.get(0, 1))?; - let size = (u32::from_be_bytes($self.stream.get(1, 4)) - 4) as usize; - - read_packet!($(@$blocking)? @stream $self, 5, size); - - // take the whole packet - $self.stream.consume(5); - let contents = $self.stream.take(size); - - let message = Message { r#type, contents }; - - match message.r#type { - MessageType::ErrorResponse => { - // An error returned from the database server. - // return Err(PgDatabaseError(message.decode()?).into()); - panic!("got error response"); - } - - MessageType::NotificationResponse => { - // if let Some(buffer) = &mut self.notifications { - // let notification: Notification = message.decode()?; - // let _ = self.write_packet(notification); - - // continue; - // } - continue; - } - - MessageType::ParameterStatus => { - // informs the frontend about the current (initial) - // setting of backend parameters - - // we currently have no use for that data so we promptly ignore this message - continue; - } - - MessageType::NoticeResponse => { - // do we need this to be more configurable? - // if you are reading this comment and think so, open an issue - - let notice: Notice = message.decode()?; - - let lvl = match notice.severity() { - PgSeverity::Fatal | PgSeverity::Panic | PgSeverity::Error => Level::Error, - PgSeverity::Warning => Level::Warn, - PgSeverity::Notice => Level::Info, - PgSeverity::Debug => Level::Debug, - PgSeverity::Info => Level::Trace, - PgSeverity::Log => Level::Trace, - }; - - if lvl <= log::STATIC_MAX_LEVEL && lvl <= log::max_level() { - log::logger().log( - &log::Record::builder() - .args(format_args!("{}", notice.message())) - .level(lvl) - .module_path_static(Some("sqlx::postgres::notice")) - .file_static(Some(file!())) - .line(Some(line!())) - .build(), - ); - } - - continue; - } - - _ => {} - } - - return T::deserialize_with(message.contents, ()); - } - }}; - - (@blocking @stream $self:ident, $offset:expr, $n:expr) => { - $self.stream.read($offset, $n)?; - }; - - (@stream $self:ident, $offset:expr, $n:expr) => { - $self.stream.read_async($offset, $n).await?; - }; -} - -impl PostgresConnection -where - Rt: Runtime, -{ - #[cfg(feature = "async")] - - pub(super) async fn read_packet_async<'de, T>(&'de mut self) -> Result - where - T: Deserialize<'de, ()> + Debug, - Rt: sqlx_core::Async, - { - read_packet!(self) - } - - #[cfg(feature = "blocking")] - - pub(super) fn read_packet<'de, T>(&'de mut self) -> Result - where - T: Deserialize<'de, ()> + Debug, - Rt: sqlx_core::blocking::Runtime, - { - read_packet!(@blocking self) - } -} diff --git a/sqlx-postgres/src/database.rs b/sqlx-postgres/src/database.rs index b3c6508e..4149fb5d 100644 --- a/sqlx-postgres/src/database.rs +++ b/sqlx-postgres/src/database.rs @@ -1,15 +1,30 @@ -use sqlx_core::{Database, HasOutput, Runtime}; +use sqlx_core::database::{HasOutput, HasRawValue}; +use sqlx_core::Database; + +use super::{PgColumn, PgOutput, PgQueryResult, PgRawValue, PgRow, PgTypeId, PgTypeInfo}; #[derive(Debug)] pub struct Postgres; -impl Database for Postgres -where - Rt: Runtime, -{ - type Connection = super::PostgresConnection; +impl Database for Postgres { + type Column = PgColumn; + + type Row = PgRow; + + type QueryResult = PgQueryResult; + + type TypeInfo = PgTypeInfo; + + type TypeId = PgTypeId; } +// 'x: execution impl<'x> HasOutput<'x> for Postgres { - type Output = &'x mut Vec; + type Output = PgOutput<'x>; +} + +// 'r: row +impl<'r> HasRawValue<'r> for Postgres { + type Database = Self; + type RawValue = PgRawValue<'r>; } diff --git a/sqlx-postgres/src/error.rs b/sqlx-postgres/src/error.rs deleted file mode 100644 index 709d4165..00000000 --- a/sqlx-postgres/src/error.rs +++ /dev/null @@ -1,41 +0,0 @@ -use std::error::Error as StdError; -use std::fmt::{self, Display, Formatter}; - -use sqlx_core::DatabaseError; - -/// An error returned from the PostgreSQL database server. -#[allow(clippy::module_name_repetitions)] -#[derive(Debug)] -pub struct PostgresDatabaseError(String); - -impl PostgresDatabaseError { - pub(crate) fn protocol(msg: String) -> PostgresDatabaseError { - PostgresDatabaseError(msg) - } -} - -impl From for PostgresDatabaseError { - fn from(err: crypto_mac::InvalidKeyLength) -> Self { - PostgresDatabaseError::protocol(err.to_string()) - } -} - -impl From for PostgresDatabaseError { - fn from(err: crypto_mac::MacError) -> Self { - PostgresDatabaseError::protocol(err.to_string()) - } -} - -impl DatabaseError for PostgresDatabaseError { - fn message(&self) -> &str { - &self.0 - } -} - -impl Display for PostgresDatabaseError { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } -} - -impl StdError for PostgresDatabaseError {} diff --git a/sqlx-postgres/src/io.rs b/sqlx-postgres/src/io.rs deleted file mode 100644 index 51f04761..00000000 --- a/sqlx-postgres/src/io.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod write; - -pub(crate) use write::PgBufMutExt; diff --git a/sqlx-postgres/src/io/write.rs b/sqlx-postgres/src/io/write.rs deleted file mode 100644 index 4b13b7c9..00000000 --- a/sqlx-postgres/src/io/write.rs +++ /dev/null @@ -1,52 +0,0 @@ -pub trait PgBufMutExt { - fn write_length_prefixed(&mut self, f: F) - where - F: FnOnce(&mut Vec); - - fn write_statement_name(&mut self, id: u32); - - fn write_portal_name(&mut self, id: Option); -} - -impl PgBufMutExt for Vec { - // writes a length-prefixed message, this is used when encoding nearly all messages as postgres - // wants us to send the length of the often-variable-sized messages up front - fn write_length_prefixed(&mut self, f: F) - where - F: FnOnce(&mut Vec), - { - // reserve space to write the prefixed length - let offset = self.len(); - self.extend(&[0; 4]); - - // write the main body of the message - f(self); - - // now calculate the size of what we wrote and set the length value - let size = (self.len() - offset) as i32; - self[offset..(offset + 4)].copy_from_slice(&size.to_be_bytes()); - } - - // writes a statement name by ID - #[inline] - fn write_statement_name(&mut self, id: u32) { - // N.B. if you change this don't forget to update it in ../describe.rs - self.extend(b"sqlx_s_"); - - itoa::write(&mut *self, id).unwrap(); - - self.push(0); - } - - // writes a portal name by ID - #[inline] - fn write_portal_name(&mut self, id: Option) { - if let Some(id) = id { - self.extend(b"sqlx_p_"); - - itoa::write(&mut *self, id).unwrap(); - } - - self.push(0); - } -} diff --git a/sqlx-postgres/src/lib.rs b/sqlx-postgres/src/lib.rs index e8ece223..5a34e279 100644 --- a/sqlx-postgres/src/lib.rs +++ b/sqlx-postgres/src/lib.rs @@ -1,6 +1,6 @@ //! [PostgreSQL] database driver. //! -//! [PostgreSQL]: https://www.postgres.com/ +//! [PostgreSQL]: https://www.postgresql.org/ //! #![cfg_attr(doc_cfg, feature(doc_cfg))] #![cfg_attr(not(any(feature = "async", feature = "blocking")), allow(unused))] @@ -17,18 +17,45 @@ #![warn(clippy::use_self)] #![warn(clippy::useless_let_if_seq)] #![allow(clippy::doc_markdown)] +#![allow(clippy::missing_errors_doc)] +#![allow(clippy::missing_panics_doc)] -mod connection; +use sqlx_core::Arguments; + +#[macro_use] +mod stream; + +mod column; +// mod connection; mod database; -mod error; -mod io; -mod options; +// mod error; +// mod io; +// mod options; +mod output; mod protocol; +mod query_result; +// mod raw_statement; +mod raw_value; +mod row; +// mod transaction; +mod type_id; +mod type_info; +pub mod types; -#[cfg(test)] -mod mock; +// #[cfg(test)] +// mod mock; -pub use connection::PostgresConnection; +pub use column::PgColumn; +// pub use connection::PgConnection; pub use database::Postgres; -pub use error::PostgresDatabaseError; -pub use options::PostgresConnectOptions; +// pub use error::PgDatabaseError; +// pub use options::PgConnectOptions; +pub use output::PgOutput; +pub use query_result::PgQueryResult; +pub use raw_value::{PgRawValue, PgRawValueFormat}; +pub use row::PgRow; +pub use type_id::PgTypeId; +pub use type_info::PgTypeInfo; + +// 'a: argument values +pub type PgArguments<'a> = Arguments<'a, Postgres>; diff --git a/sqlx-postgres/src/options.rs b/sqlx-postgres/src/options.rs deleted file mode 100644 index 06ee91bd..00000000 --- a/sqlx-postgres/src/options.rs +++ /dev/null @@ -1,103 +0,0 @@ -use std::fmt::{self, Debug, Formatter}; -use std::marker::PhantomData; -use std::path::PathBuf; - -use either::Either; -use sqlx_core::{ConnectOptions, Runtime}; - -use crate::PostgresConnection; - -mod builder; -mod default; -mod getters; -mod parse; - -// TODO: RSA Public Key (to avoid the key exchange for caching_sha2 and sha256 plugins) - -/// Options which can be used to configure how a PostgreSQL connection is opened. -/// -#[allow(clippy::module_name_repetitions)] -pub struct PostgresConnectOptions -where - Rt: Runtime, -{ - runtime: PhantomData, - pub(crate) address: Either<(String, u16), PathBuf>, - username: Option, - password: Option, - database: Option, - timezone: String, - charset: String, -} - -impl Clone for PostgresConnectOptions -where - Rt: Runtime, -{ - fn clone(&self) -> Self { - Self { - runtime: PhantomData, - address: self.address.clone(), - username: self.username.clone(), - password: self.password.clone(), - database: self.database.clone(), - timezone: self.timezone.clone(), - charset: self.charset.clone(), - } - } -} - -impl Debug for PostgresConnectOptions -where - Rt: Runtime, -{ - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.debug_struct("PostgresConnectOptions") - .field( - "address", - &self - .address - .as_ref() - .map_left(|(host, port)| format!("{}:{}", host, port)) - .map_right(|socket| socket.display()), - ) - .field("username", &self.username) - .field("password", &self.password) - .field("database", &self.database) - .field("timezone", &self.timezone) - .field("charset", &self.charset) - .finish() - } -} - -impl ConnectOptions for PostgresConnectOptions -where - Rt: Runtime, -{ - type Connection = PostgresConnection; - - #[cfg(feature = "async")] - fn connect(&self) -> futures_util::future::BoxFuture<'_, sqlx_core::Result> - where - Self::Connection: Sized, - Rt: sqlx_core::Async, - { - Box::pin(PostgresConnection::::connect_async(self)) - } -} - -#[cfg(feature = "blocking")] -mod blocking { - use sqlx_core::blocking::{ConnectOptions, Runtime}; - - use super::{PostgresConnectOptions, PostgresConnection}; - - impl ConnectOptions for PostgresConnectOptions { - fn connect(&self) -> sqlx_core::Result - where - Self::Connection: Sized, - { - >::connect(self) - } - } -} diff --git a/sqlx-postgres/src/options/builder.rs b/sqlx-postgres/src/options/builder.rs deleted file mode 100644 index 07feb7de..00000000 --- a/sqlx-postgres/src/options/builder.rs +++ /dev/null @@ -1,82 +0,0 @@ -use std::mem; -use std::path::{Path, PathBuf}; - -use either::Either; -use sqlx_core::Runtime; - -impl super::PostgresConnectOptions -where - Rt: Runtime, -{ - /// Sets the hostname of the database server. - /// - /// If the hostname begins with a slash (`/`), it is interpreted as the absolute path - /// to a Unix domain socket file instead of a hostname of a server. - /// - /// Defaults to `localhost`. - /// - pub fn host(&mut self, host: impl AsRef) -> &mut Self { - let host = host.as_ref(); - - self.address = if host.starts_with('/') { - Either::Right(PathBuf::from(&*host)) - } else { - Either::Left((host.into(), self.get_port())) - }; - - self - } - - /// Sets the path of the Unix domain socket to connect to. - /// - /// Overrides [`host()`](#method.host) and [`port()`](#method.port). - /// - pub fn socket(&mut self, socket: impl AsRef) -> &mut Self { - self.address = Either::Right(socket.as_ref().to_owned()); - self - } - - /// Sets the TCP port number of the database server. - /// - /// Defaults to `3306`. - /// - pub fn port(&mut self, port: u16) -> &mut Self { - self.address = match self.address { - Either::Right(_) => Either::Left(("localhost".to_owned(), port)), - Either::Left((ref mut host, _)) => Either::Left((mem::take(host), port)), - }; - - self - } - - /// Sets the username to be used for authentication. - // FIXME: Specify what happens when you do NOT set this - pub fn username(&mut self, username: impl AsRef) -> &mut Self { - self.username = Some(username.as_ref().to_owned()); - self - } - - /// Sets the password to be used for authentication. - pub fn password(&mut self, password: impl AsRef) -> &mut Self { - self.password = Some(password.as_ref().to_owned()); - self - } - - /// Sets the default database for the connection. - pub fn database(&mut self, database: impl AsRef) -> &mut Self { - self.database = Some(database.as_ref().to_owned()); - self - } - - /// Sets the character set for the connection. - pub fn charset(&mut self, charset: impl AsRef) -> &mut Self { - self.charset = charset.as_ref().to_owned(); - self - } - - /// Sets the timezone for the connection. - pub fn timezone(&mut self, timezone: impl AsRef) -> &mut Self { - self.timezone = timezone.as_ref().to_owned(); - self - } -} diff --git a/sqlx-postgres/src/options/default.rs b/sqlx-postgres/src/options/default.rs deleted file mode 100644 index 053cc20c..00000000 --- a/sqlx-postgres/src/options/default.rs +++ /dev/null @@ -1,38 +0,0 @@ -use std::marker::PhantomData; - -use either::Either; -use sqlx_core::Runtime; - -use crate::PostgresConnectOptions; - -pub(crate) const HOST: &str = "localhost"; -pub(crate) const PORT: u16 = 3306; - -impl Default for PostgresConnectOptions -where - Rt: Runtime, -{ - fn default() -> Self { - Self { - runtime: PhantomData, - address: Either::Left((HOST.to_owned(), PORT)), - username: None, - password: None, - database: None, - charset: "utf8mb4".to_owned(), - timezone: "utc".to_owned(), - // todo: connect_timeout - } - } -} - -impl super::PostgresConnectOptions -where - Rt: Runtime, -{ - /// Creates a default set of options ready for configuration. - #[must_use] - pub fn new() -> Self { - Self::default() - } -} diff --git a/sqlx-postgres/src/options/getters.rs b/sqlx-postgres/src/options/getters.rs deleted file mode 100644 index b3e7f6f9..00000000 --- a/sqlx-postgres/src/options/getters.rs +++ /dev/null @@ -1,55 +0,0 @@ -use std::path::{Path, PathBuf}; - -use sqlx_core::Runtime; - -use super::{default, PostgresConnectOptions}; - -impl PostgresConnectOptions { - /// Returns the hostname of the database server. - #[must_use] - pub fn get_host(&self) -> &str { - self.address.as_ref().left().map_or(default::HOST, |(host, _)| &**host) - } - - /// Returns the TCP port number of the database server. - #[must_use] - pub fn get_port(&self) -> u16 { - self.address.as_ref().left().map_or(default::PORT, |(_, port)| *port) - } - - /// Returns the path to the Unix domain socket, if one is configured. - #[must_use] - pub fn get_socket(&self) -> Option<&Path> { - self.address.as_ref().right().map(PathBuf::as_path) - } - - /// Returns the default database name. - #[must_use] - pub fn get_database(&self) -> Option<&str> { - self.database.as_deref() - } - - /// Returns the username to be used for authentication. - #[must_use] - pub fn get_username(&self) -> Option<&str> { - self.username.as_deref() - } - - /// Returns the password to be used for authentication. - #[must_use] - pub fn get_password(&self) -> Option<&str> { - self.password.as_deref() - } - - /// Returns the character set for the connection. - #[must_use] - pub fn get_charset(&self) -> &str { - &self.charset - } - - /// Returns the timezone for the connection. - #[must_use] - pub fn get_timezone(&self) -> &str { - &self.timezone - } -} diff --git a/sqlx-postgres/src/options/parse.rs b/sqlx-postgres/src/options/parse.rs deleted file mode 100644 index 2c395e82..00000000 --- a/sqlx-postgres/src/options/parse.rs +++ /dev/null @@ -1,183 +0,0 @@ -use std::borrow::Cow; -use std::str::FromStr; - -use percent_encoding::percent_decode_str; -use sqlx_core::{Error, Runtime}; -use url::Url; - -use crate::PostgresConnectOptions; - -impl FromStr for PostgresConnectOptions -where - Rt: Runtime, -{ - type Err = Error; - - fn from_str(s: &str) -> Result { - let url: Url = - s.parse().map_err(|error| Error::configuration("for database URL", error))?; - - if !matches!(url.scheme(), "postgres") { - return Err(Error::configuration_msg(format!( - "unsupported URL scheme {:?} for MySQL", - url.scheme() - ))); - } - - let mut options = Self::new(); - - if let Some(host) = url.host_str() { - options.host(percent_decode_str_utf8(host)); - } - - if let Some(port) = url.port() { - options.port(port); - } - - let username = url.username(); - if !username.is_empty() { - options.username(percent_decode_str_utf8(username)); - } - - if let Some(password) = url.password() { - options.password(percent_decode_str_utf8(password)); - } - - let mut path = url.path(); - - if path.starts_with('/') { - path = &path[1..]; - } - - if !path.is_empty() { - options.database(path); - } - - for (key, value) in url.query_pairs() { - match &*key { - "user" | "username" => { - options.username(value); - } - - "password" => { - options.password(value); - } - - // ssl-mode compatibly with SQLx <= 0.5 - // sslmode compatibly with PostgreSQL - // sslMode compatibly with JDBC MySQL - // tls compatibly with Go MySQL [preferred] - "ssl-mode" | "sslmode" | "sslMode" | "tls" => { - todo!() - } - - "charset" => { - options.charset(value); - } - - "timezone" => { - options.timezone(value); - } - - "socket" => { - options.socket(&*value); - } - - _ => { - // ignore unknown connection parameters - // fixme: should we error or warn here? - } - } - } - - Ok(options) - } -} - -// todo: this should probably go somewhere common -fn percent_decode_str_utf8(value: &str) -> Cow<'_, str> { - percent_decode_str(value).decode_utf8_lossy() -} - -#[cfg(test)] -mod tests { - use std::path::Path; - - use sqlx_core::mock::Mock; - - use super::PostgresConnectOptions; - - #[test] - fn parse() { - let url = "postgres://user:password@hostname:5432/database?timezone=system&charset=utf8"; - let options: PostgresConnectOptions = url.parse().unwrap(); - - assert_eq!(options.get_username(), Some("user")); - assert_eq!(options.get_password(), Some("password")); - assert_eq!(options.get_host(), "hostname"); - assert_eq!(options.get_port(), 5432); - assert_eq!(options.get_database(), Some("database")); - assert_eq!(options.get_timezone(), "system"); - assert_eq!(options.get_charset(), "utf8"); - } - - #[test] - fn parse_with_defaults() { - let url = "postgres://"; - let options: PostgresConnectOptions = url.parse().unwrap(); - - assert_eq!(options.get_username(), None); - assert_eq!(options.get_password(), None); - assert_eq!(options.get_host(), "localhost"); - assert_eq!(options.get_port(), 3306); - assert_eq!(options.get_database(), None); - assert_eq!(options.get_timezone(), "utc"); - assert_eq!(options.get_charset(), "utf8mb4"); - } - - #[test] - fn parse_socket_from_query() { - let url = "postgres://user:password@localhost/database?socket=/var/run/postgresd/postgresd.sock"; - let options: PostgresConnectOptions = url.parse().unwrap(); - - assert_eq!(options.get_username(), Some("user")); - assert_eq!(options.get_password(), Some("password")); - assert_eq!(options.get_database(), Some("database")); - assert_eq!(options.get_socket(), Some(Path::new("/var/run/postgresd/postgresd.sock"))); - } - - #[test] - fn parse_socket_from_host() { - // socket path in host requires URL encoding - but does work - let url = "postgres://user:password@%2Fvar%2Frun%2Fpostgresd%2Fpostgresd.sock/database"; - let options: PostgresConnectOptions = url.parse().unwrap(); - - assert_eq!(options.get_username(), Some("user")); - assert_eq!(options.get_password(), Some("password")); - assert_eq!(options.get_database(), Some("database")); - assert_eq!(options.get_socket(), Some(Path::new("/var/run/postgresd/postgresd.sock"))); - } - - #[test] - #[should_panic] - fn fail_to_parse_non_postgres() { - let url = "postgres://user:password@hostname:5432/database?timezone=system&charset=utf8"; - let _: PostgresConnectOptions = url.parse().unwrap(); - } - - #[test] - fn parse_username_with_at_sign() { - let url = "postgres://user@hostname:password@hostname:5432/database"; - let options: PostgresConnectOptions = url.parse().unwrap(); - - assert_eq!(options.get_username(), Some("user@hostname")); - } - - #[test] - fn parse_password_with_non_ascii_chars() { - let url = "postgres://username:p@ssw0rd@hostname:5432/database"; - let options: PostgresConnectOptions = url.parse().unwrap(); - - assert_eq!(options.get_password(), Some("p@ssw0rd")); - } -} diff --git a/sqlx-postgres/src/output.rs b/sqlx-postgres/src/output.rs new file mode 100644 index 00000000..c7f57938 --- /dev/null +++ b/sqlx-postgres/src/output.rs @@ -0,0 +1,15 @@ +// 'x: execution +#[allow(clippy::module_name_repetitions)] +pub struct PgOutput<'x> { + buffer: &'x mut Vec, +} + +impl<'x> PgOutput<'x> { + pub(crate) fn new(buffer: &'x mut Vec) -> Self { + Self { buffer } + } + + pub(crate) fn buffer(&mut self) -> &mut Vec { + self.buffer + } +} diff --git a/sqlx-postgres/src/protocol.rs b/sqlx-postgres/src/protocol.rs index 2195504e..3ef8d5de 100644 --- a/sqlx-postgres/src/protocol.rs +++ b/sqlx-postgres/src/protocol.rs @@ -1,111 +1,3 @@ -use std::convert::TryFrom; +mod message; -use bytes::{Buf, Bytes}; -use sqlx_core::io::Deserialize; -use sqlx_core::Error; -use sqlx_core::Result; - -mod authentication; -mod backend_key_data; -mod close; -mod notification; -mod password; -mod ready_for_query; -mod response; -mod sasl; -mod startup; -mod terminate; - -pub(crate) use authentication::{Authentication, AuthenticationSasl}; -pub(crate) use backend_key_data::BackendKeyData; -pub(crate) use close::Close; -pub(crate) use notification::Notification; -pub(crate) use password::Password; -pub(crate) use ready_for_query::ReadyForQuery; -pub(crate) use response::{Notice, PgSeverity}; -pub(crate) use sasl::{SaslInitialResponse, SaslResponse}; -pub(crate) use startup::Startup; -pub(crate) use terminate::Terminate; - -#[derive(Debug, Clone, Copy, PartialOrd, PartialEq, Eq)] -#[repr(u8)] -pub enum MessageType { - ParseComplete = b'1', - BindComplete = b'2', - CloseComplete = b'3', - CommandComplete = b'C', - DataRow = b'D', - ErrorResponse = b'E', - EmptyQueryResponse = b'I', - NotificationResponse = b'A', - BackendKeyData = b'K', - NoticeResponse = b'N', - Authentication = b'R', - ParameterStatus = b'S', - RowDescription = b'T', - ReadyForQuery = b'Z', - NoData = b'n', - PortalSuspended = b's', - ParameterDescription = b't', -} - -#[derive(Debug)] -pub struct Message { - pub r#type: MessageType, - pub contents: Bytes, -} - -impl Message { - #[inline] - pub fn decode<'de, T>(self) -> Result - where - T: Deserialize<'de, ()>, - { - T::deserialize_with(self.contents, ()) - } -} - -impl TryFrom for MessageType { - type Error = Error; - - fn try_from(v: u8) -> Result { - // https://www.postgresql.org/docs/current/protocol-message-formats.html - - Ok(match v { - b'1' => MessageType::ParseComplete, - b'2' => MessageType::BindComplete, - b'3' => MessageType::CloseComplete, - b'C' => MessageType::CommandComplete, - b'D' => MessageType::DataRow, - b'E' => MessageType::ErrorResponse, - b'I' => MessageType::EmptyQueryResponse, - b'A' => MessageType::NotificationResponse, - b'K' => MessageType::BackendKeyData, - b'N' => MessageType::NoticeResponse, - b'R' => MessageType::Authentication, - b'S' => MessageType::ParameterStatus, - b'T' => MessageType::RowDescription, - b'Z' => MessageType::ReadyForQuery, - b'n' => MessageType::NoData, - b's' => MessageType::PortalSuspended, - b't' => MessageType::ParameterDescription, - - _ => { - return Err(Error::configuration_msg(format!( - "unknown message type: {:?}", - v as char - ))); - } - }) - } -} - -impl Deserialize<'_, ()> for Message { - fn deserialize_with(mut buf: Bytes, _: ()) -> Result { - let r#type = MessageType::try_from(buf.get_u8())?; - let size = buf.get_u32() - 4; - let contents = buf.split_to(size as usize); - - Ok(Message { r#type, contents }) - } -} +pub(crate) use message::{BackendMessage, BackendMessageType}; diff --git a/sqlx-postgres/src/protocol/authentication.rs b/sqlx-postgres/src/protocol/authentication.rs deleted file mode 100644 index 8479c9d4..00000000 --- a/sqlx-postgres/src/protocol/authentication.rs +++ /dev/null @@ -1,204 +0,0 @@ -use bytes::{Buf, Bytes}; -use memchr::memchr; -use sqlx_core::io::Deserialize; -use sqlx_core::Error; -use sqlx_core::Result; - -// On startup, the server sends an appropriate authentication request message, -// to which the frontend must reply with an appropriate authentication -// response message (such as a password). - -// For all authentication methods except GSSAPI, SSPI and SASL, there is at -// most one request and one response. In some methods, no response at all is -// needed from the frontend, and so no authentication request occurs. - -// For GSSAPI, SSPI and SASL, multiple exchanges of packets may -// be needed to complete the authentication. - -// -// - -#[derive(Debug)] -pub enum Authentication { - /// The authentication exchange is successfully completed. - Ok, - - /// The frontend must now send a [PasswordMessage] containing the - /// password in clear-text form. - CleartextPassword, - - /// The frontend must now send a [PasswordMessage] containing the - /// password (with user name) encrypted via MD5, then encrypted - /// again using the 4-byte random salt. - Md5Password(AuthenticationMd5Password), - - /// The frontend must now initiate a SASL negotiation, - /// using one of the SASL mechanisms listed in the message. - /// - /// The frontend will send a [SaslInitialResponse] with the name - /// of the selected mechanism, and the first part of the SASL - /// data stream in response to this. - /// - /// If further messages are needed, the server will - /// respond with [Authentication::SaslContinue]. - Sasl(AuthenticationSasl), - - /// This message contains challenge data from the previous step of SASL negotiation. - /// - /// The frontend must respond with a [SaslResponse] message. - SaslContinue(AuthenticationSaslContinue), - - /// SASL authentication has completed with additional mechanism-specific - /// data for the client. - /// - /// The server will next send [Authentication::Ok] to - /// indicate successful authentication. - SaslFinal(AuthenticationSaslFinal), -} - -impl Deserialize<'_, ()> for Authentication { - fn deserialize_with(mut buf: Bytes, _: ()) -> Result { - Ok(match buf.get_u32() { - 0 => Authentication::Ok, - - 3 => Authentication::CleartextPassword, - - 5 => { - let mut salt = [0; 4]; - buf.copy_to_slice(&mut salt); - - Authentication::Md5Password(AuthenticationMd5Password { salt }) - } - - 10 => Authentication::Sasl(AuthenticationSasl(buf)), - - 11 => { - Authentication::SaslContinue(AuthenticationSaslContinue::deserialize_with(buf, ())?) - } - - 12 => Authentication::SaslFinal(AuthenticationSaslFinal::deserialize_with(buf, ())?), - - ty => { - return Err(Error::configuration_msg(format!( - "unknown authentication method: {}", - ty - ))); - } - }) - } -} - -/// Body of [Authentication::Md5Password]. -#[derive(Debug)] -pub struct AuthenticationMd5Password { - pub salt: [u8; 4], -} - -/// Body of [Authentication::Sasl]. -#[derive(Debug)] -pub struct AuthenticationSasl(Bytes); - -impl AuthenticationSasl { - #[inline] - pub fn mechanisms(&self) -> SaslMechanisms<'_> { - SaslMechanisms(&self.0) - } -} - -/// An iterator over the SASL authentication mechanisms provided by the server. -pub struct SaslMechanisms<'a>(&'a [u8]); - -impl<'a> Iterator for SaslMechanisms<'a> { - type Item = &'a str; - - fn next(&mut self) -> Option { - if !self.0.is_empty() && self.0[0] == b'\0' { - return None; - } - - #[allow(unsafe_code)] - let mechanism = memchr(b'\0', self.0) - // UNSAFE: Postgres is expecte to return a valid UTF-8 string here - .and_then(|nul| Some(unsafe { std::str::from_utf8_unchecked(&self.0[..nul]) }))?; - - self.0 = &self.0[(mechanism.len() + 1)..]; - - Some(mechanism) - } -} - -#[derive(Debug)] -pub struct AuthenticationSaslContinue { - pub salt: Vec, - pub iterations: u32, - pub nonce: String, - pub message: String, -} - -impl Deserialize<'_, ()> for AuthenticationSaslContinue { - fn deserialize_with(mut buf: Bytes, _: ()) -> Result { - let mut iterations: u32 = 4096; - let mut salt = Vec::new(); - let mut nonce = Bytes::new(); - - // [Example] - // r=/z+giZiTxAH7r8sNAeHr7cvpqV3uo7G/bJBIJO3pjVM7t3ng,s=4UV68bIkC8f9/X8xH7aPhg==,i=4096 - - for item in buf.split(|b| *b == b',') { - let key = item[0]; - let value = &item[2..]; - - match key { - b'r' => { - nonce = buf.slice_ref(value); - } - - b'i' => { - iterations = atoi::atoi(value).unwrap_or(4096); - } - - b's' => { - // TODO: Map error correctly - salt = base64::decode(value).unwrap(); - } - - _ => {} - } - } - - #[allow(unsafe_code)] - Ok(Self { - iterations, - salt, - - // UNSAFE: Postgres is expected to return a valid UTF-8 string here - nonce: unsafe { String::from_utf8_unchecked((*nonce).to_vec()) }, - - // UNSAFE: Postgres is expected to return a valid UTF-8 string here - message: unsafe { String::from_utf8_unchecked((*buf).to_vec()) }, - }) - } -} - -#[derive(Debug)] -pub struct AuthenticationSaslFinal { - pub verifier: Vec, -} - -impl Deserialize<'_, ()> for AuthenticationSaslFinal { - fn deserialize_with(mut buf: Bytes, _: ()) -> Result { - let mut verifier = Vec::new(); - - for item in buf.split(|b| *b == b',') { - let key = item[0]; - let value = &item[2..]; - - if let b'v' = key { - // TODO: Map error correctly - verifier = base64::decode(value).unwrap(); - } - } - - Ok(Self { verifier }) - } -} diff --git a/sqlx-postgres/src/protocol/backend_key_data.rs b/sqlx-postgres/src/protocol/backend_key_data.rs deleted file mode 100644 index 0e9ebbd9..00000000 --- a/sqlx-postgres/src/protocol/backend_key_data.rs +++ /dev/null @@ -1,25 +0,0 @@ -use byteorder::{BigEndian, ByteOrder}; -use bytes::Bytes; -use sqlx_core::io::Deserialize; -use sqlx_core::Error; -use sqlx_core::Result; - -/// Contains cancellation key data. The frontend must save these values if it -/// wishes to be able to issue `CancelRequest` messages later. -#[derive(Debug)] -pub struct BackendKeyData { - /// The process ID of this database. - pub process_id: u32, - - /// The secret key of this database. - pub secret_key: u32, -} - -impl Deserialize<'_, ()> for BackendKeyData { - fn deserialize_with(mut buf: Bytes, _: ()) -> Result { - let process_id = BigEndian::read_u32(&buf); - let secret_key = BigEndian::read_u32(&buf[4..]); - - Ok(Self { process_id, secret_key }) - } -} diff --git a/sqlx-postgres/src/protocol/close.rs b/sqlx-postgres/src/protocol/close.rs deleted file mode 100644 index 95e916b9..00000000 --- a/sqlx-postgres/src/protocol/close.rs +++ /dev/null @@ -1,36 +0,0 @@ -use sqlx_core::io::Serialize; -use sqlx_core::Result; - -use crate::io::PgBufMutExt; - -const CLOSE_PORTAL: u8 = b'P'; -const CLOSE_STATEMENT: u8 = b'S'; - -#[derive(Debug)] -#[allow(dead_code)] -pub enum Close { - Statement(u32), - Portal(u32), -} - -impl Serialize<'_, ()> for Close { - fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { - // 15 bytes for 1-digit statement/portal IDs - buf.reserve(20); - buf.push(b'C'); - - buf.write_length_prefixed(|buf| match self { - Close::Statement(id) => { - buf.push(CLOSE_STATEMENT); - buf.write_statement_name(*id); - } - - Close::Portal(id) => { - buf.push(CLOSE_PORTAL); - buf.write_portal_name(Some(*id)); - } - }); - - Ok(()) - } -} diff --git a/sqlx-postgres/src/protocol/flush.rs b/sqlx-postgres/src/protocol/flush.rs deleted file mode 100644 index 6967a1e5..00000000 --- a/sqlx-postgres/src/protocol/flush.rs +++ /dev/null @@ -1,18 +0,0 @@ -use sqlx_core::io::Serialize; -use sqlx_core::Result; - -// The Flush message does not cause any specific output to be generated, -// but forces the backend to deliver any data pending in its output buffers. - -// A Flush must be sent after any extended-query command except Sync, if the -// frontend wishes to examine the results of that command before issuing more commands. - -#[derive(Debug)] -pub struct Flush; - -impl Serialize<'_, ()> for Flush { - fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { - buf.push(b'H'); - buf.extend(&4_i32.to_be_bytes()); - } -} diff --git a/sqlx-postgres/src/protocol/message.rs b/sqlx-postgres/src/protocol/message.rs new file mode 100644 index 00000000..458eb7e7 --- /dev/null +++ b/sqlx-postgres/src/protocol/message.rs @@ -0,0 +1,93 @@ +use std::convert::TryFrom; +use std::fmt::Debug; + +use bytes::Bytes; +use sqlx_core::io::Deserialize; +use sqlx_core::{Error, Result}; + +/// Type of the *incoming* message. +/// +/// Postgres does use the same message format for client and server messages but we are only +/// interested in messages from the backend. +/// +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub(crate) enum BackendMessageType { + ParseComplete = b'1', + BindComplete = b'2', + CloseComplete = b'3', + CommandComplete = b'C', + DataRow = b'D', + ErrorResponse = b'E', + EmptyQueryResponse = b'I', + NotificationResponse = b'A', + BackendKeyData = b'K', + NoticeResponse = b'N', + Authentication = b'R', + ParameterStatus = b'S', + RowDescription = b'T', + ReadyForQuery = b'Z', + NoData = b'n', + PortalSuspended = b's', + ParameterDescription = b't', + CopyInResponse = b'G', + CopyOutResponse = b'H', + CopyBothResponse = b'W', + CopyData = b'd', + CopyDone = b'c', +} + +impl TryFrom for BackendMessageType { + type Error = Error; + + fn try_from(ty: u8) -> Result { + Ok(match ty { + b'1' => Self::ParseComplete, + b'2' => Self::BindComplete, + b'3' => Self::CloseComplete, + b'C' => Self::CommandComplete, + b'D' => Self::DataRow, + b'E' => Self::ErrorResponse, + b'I' => Self::EmptyQueryResponse, + b'A' => Self::NotificationResponse, + b'K' => Self::BackendKeyData, + b'N' => Self::NoticeResponse, + b'R' => Self::Authentication, + b'S' => Self::ParameterStatus, + b'T' => Self::RowDescription, + b'Z' => Self::ReadyForQuery, + b'n' => Self::NoData, + b's' => Self::PortalSuspended, + b't' => Self::ParameterDescription, + b'G' => Self::CopyInResponse, + b'H' => Self::CopyOutResponse, + b'W' => Self::CopyBothResponse, + b'd' => Self::CopyData, + b'c' => Self::CopyDone, + + _ => { + todo!("protocol unexpected data error") + } + }) + } +} + +#[derive(Debug)] +pub(crate) struct BackendMessage { + pub(crate) r#type: BackendMessageType, + pub(crate) contents: Bytes, +} + +impl BackendMessage { + #[inline] + pub(crate) fn deserialize<'de, T>(self) -> Result + where + T: Deserialize<'de> + Debug, + { + let packet = T::deserialize(self.contents)?; + + log::trace!("read > {:?}", packet); + + Ok(packet) + } +} diff --git a/sqlx-postgres/src/protocol/notification.rs b/sqlx-postgres/src/protocol/notification.rs deleted file mode 100644 index db200fee..00000000 --- a/sqlx-postgres/src/protocol/notification.rs +++ /dev/null @@ -1,28 +0,0 @@ -use bytes::{Buf, Bytes}; -use bytestring::ByteString; -use sqlx_core::io::BufExt; -use sqlx_core::io::Deserialize; -use sqlx_core::Result; - -#[derive(Debug)] -pub struct Notification { - pub(crate) process_id: u32, - pub(crate) channel: ByteString, - pub(crate) payload: ByteString, -} - -impl Deserialize<'_, ()> for Notification { - fn deserialize_with(mut buf: Bytes, _: ()) -> Result { - let process_id = buf.get_u32(); - - // UNSAFE: This message will not be read. - #[allow(unsafe_code)] - let channel = unsafe { buf.get_str_nul_unchecked()? }; - - // UNSAFE: This message will not be read. - #[allow(unsafe_code)] - let payload = unsafe { buf.get_str_nul_unchecked()? }; - - Ok(Self { process_id, channel, payload }) - } -} diff --git a/sqlx-postgres/src/protocol/password.rs b/sqlx-postgres/src/protocol/password.rs deleted file mode 100644 index 15fcea2b..00000000 --- a/sqlx-postgres/src/protocol/password.rs +++ /dev/null @@ -1,67 +0,0 @@ -use std::fmt::Write; - -use md5::{Digest, Md5}; -use sqlx_core::io::Serialize; -use sqlx_core::io::WriteExt; -use sqlx_core::Result; - -use crate::io::PgBufMutExt; - -#[derive(Debug)] -pub enum Password<'a> { - Cleartext(&'a str), - - Md5 { password: &'a str, username: &'a str, salt: [u8; 4] }, -} - -impl Password<'_> { - #[inline] - fn len(&self) -> usize { - match self { - Password::Cleartext(s) => s.len() + 5, - Password::Md5 { .. } => 35 + 5, - } - } -} - -impl Serialize<'_, ()> for Password<'_> { - fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { - buf.reserve(1 + 4 + self.len()); - buf.push(b'p'); - - buf.write_length_prefixed(|buf| { - match self { - Password::Cleartext(password) => { - buf.write_str_nul(password); - } - - Password::Md5 { username, password, salt } => { - // The actual `PasswordMessage` can be comwriteed in SQL as - // `concat('md5', md5(concat(md5(concat(password, username)), random-salt)))`. - - // Keep in mind the md5() function returns its result as a hex string. - - let mut hasher = Md5::new(); - - hasher.update(password); - hasher.update(username); - - let mut output = String::with_capacity(35); - - let _ = write!(output, "{:x}", hasher.finalize_reset()); - - hasher.update(&output); - hasher.update(salt); - - output.clear(); - - let _ = write!(output, "md5{:x}", hasher.finalize()); - - buf.write_str_nul(&output); - } - } - }); - - Ok(()) - } -} diff --git a/sqlx-postgres/src/protocol/ready_for_query.rs b/sqlx-postgres/src/protocol/ready_for_query.rs deleted file mode 100644 index 81a4159c..00000000 --- a/sqlx-postgres/src/protocol/ready_for_query.rs +++ /dev/null @@ -1,51 +0,0 @@ -use std::convert::TryFrom; - -use bytes::Bytes; -use sqlx_core::io::Deserialize; -use sqlx_core::Error; -use sqlx_core::Result; - -#[derive(Debug)] -#[repr(u8)] -pub(crate) enum TransactionStatus { - /// Not in a transaction block. - Idle = b'I', - - /// In a transaction block. - Transaction = b'T', - - /// In a _failed_ transaction block. Queries will be rejected until block is ended. - Error = b'E', -} - -impl TryFrom for TransactionStatus { - type Error = Error; - - fn try_from(value: u8) -> Result { - match value { - b'I' => Ok(TransactionStatus::Idle), - b'T' => Ok(TransactionStatus::Transaction), - b'E' => Ok(TransactionStatus::Error), - - status => { - return Err(Error::configuration_msg(format!( - "unknown transaction status: {:?}", - status as char, - ))); - } - } - } -} - -#[derive(Debug)] -pub(crate) struct ReadyForQuery { - pub transaction_status: TransactionStatus, -} - -impl Deserialize<'_, ()> for ReadyForQuery { - fn deserialize_with(mut buf: Bytes, _: ()) -> Result { - let transaction_status = TransactionStatus::try_from(buf[0])?; - - Ok(Self { transaction_status }) - } -} diff --git a/sqlx-postgres/src/protocol/response.rs b/sqlx-postgres/src/protocol/response.rs deleted file mode 100644 index 93edf7f3..00000000 --- a/sqlx-postgres/src/protocol/response.rs +++ /dev/null @@ -1,190 +0,0 @@ -use std::str::from_utf8; - -use bytes::Bytes; -use memchr::memchr; -use sqlx_core::io::Deserialize; -use sqlx_core::Error; -use sqlx_core::Result; - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -#[repr(u8)] -pub enum PgSeverity { - Panic, - Fatal, - Error, - Warning, - Notice, - Debug, - Info, - Log, -} - -impl PgSeverity { - #[inline] - pub fn is_error(self) -> bool { - matches!(self, Self::Panic | Self::Fatal | Self::Error) - } -} - -impl std::convert::TryFrom<&str> for PgSeverity { - type Error = Error; - - fn try_from(s: &str) -> Result { - let result = match s { - "PANIC" => PgSeverity::Panic, - "FATAL" => PgSeverity::Fatal, - "ERROR" => PgSeverity::Error, - "WARNING" => PgSeverity::Warning, - "NOTICE" => PgSeverity::Notice, - "DEBUG" => PgSeverity::Debug, - "INFO" => PgSeverity::Info, - "LOG" => PgSeverity::Log, - - severity => { - return Err(Error::configuration_msg(format!("unknown severity: {:?}", severity))); - } - }; - - Ok(result) - } -} - -#[derive(Debug)] -pub struct Notice { - storage: Bytes, - severity: PgSeverity, - message: (u16, u16), - code: (u16, u16), -} - -impl Notice { - #[inline] - pub fn severity(&self) -> PgSeverity { - self.severity - } - - #[inline] - pub fn code(&self) -> &str { - self.get_cached_str(self.code) - } - - #[inline] - pub fn message(&self) -> &str { - self.get_cached_str(self.message) - } - - // Field descriptions available here: - // https://www.postgresql.org/docs/current/protocol-error-fields.html - - #[inline] - pub fn get(&self, ty: u8) -> Option<&str> { - self.get_raw(ty).and_then(|v| from_utf8(v).ok()) - } - - pub fn get_raw(&self, ty: u8) -> Option<&[u8]> { - self.fields() - .filter(|(field, _)| *field == ty) - .map(|(_, (start, end))| &self.storage[start as usize..end as usize]) - .next() - } -} - -impl Notice { - #[inline] - fn fields(&self) -> Fields<'_> { - Fields { storage: &self.storage, offset: 0 } - } - - #[inline] - fn get_cached_str(&self, cache: (u16, u16)) -> &str { - // unwrap: this cannot fail at this stage - from_utf8(&self.storage[cache.0 as usize..cache.1 as usize]).unwrap() - } -} - -impl Deserialize<'_, ()> for Notice { - fn deserialize_with(mut buf: Bytes, _: ()) -> Result { - // In order to support PostgreSQL 9.5 and older we need to parse the localized S field. - // Newer versions additionally come with the V field that is guaranteed to be in English. - // We thus read both versions and prefer the unlocalized one if available. - const DEFAULT_SEVERITY: PgSeverity = PgSeverity::Log; - let mut severity_v = None; - let mut severity_s = None; - let mut message = (0, 0); - let mut code = (0, 0); - - // we cache the three always present fields - // this enables to keep the access time down for the fields most likely accessed - - let fields = Fields { storage: &buf, offset: 0 }; - - for (field, v) in fields { - if message.0 != 0 && code.0 != 0 { - // stop iterating when we have the 3 fields we were looking for - // we assume V (severity) was the first field as it should be - break; - } - - use std::convert::TryInto; - match field { - b'S' => { - // Discard potential errors, because the message might be localized - severity_s = - from_utf8(&buf[v.0 as usize..v.1 as usize]).unwrap().try_into().ok(); - } - - b'V' => { - // Propagate errors here, because V is not localized and thus we are missing a possible - // variant. - severity_v = - Some(from_utf8(&buf[v.0 as usize..v.1 as usize]).unwrap().try_into()?); - } - - b'M' => { - message = v; - } - - b'C' => { - code = v; - } - - _ => {} - } - } - - Ok(Self { - severity: severity_v.or(severity_s).unwrap_or(DEFAULT_SEVERITY), - message, - code, - storage: buf, - }) - } -} - -/// An iterator over each field in the Error (or Notice) response. -struct Fields<'a> { - storage: &'a [u8], - offset: u16, -} - -impl<'a> Iterator for Fields<'a> { - type Item = (u8, (u16, u16)); - - fn next(&mut self) -> Option { - // The fields in the response body are sequentially stored as [tag][string], - // ending in a final, additional [nul] - - let ty = self.storage[self.offset as usize]; - - if ty == 0 { - return None; - } - - let nul = memchr(b'\0', &self.storage[(self.offset + 1) as usize..])? as u16; - let offset = self.offset; - - self.offset += nul + 2; - - Some((ty, (offset + 1, offset + nul + 1))) - } -} diff --git a/sqlx-postgres/src/protocol/sasl.rs b/sqlx-postgres/src/protocol/sasl.rs deleted file mode 100644 index 1f650fd3..00000000 --- a/sqlx-postgres/src/protocol/sasl.rs +++ /dev/null @@ -1,40 +0,0 @@ -use sqlx_core::io::Serialize; -use sqlx_core::io::WriteExt; -use sqlx_core::Result; - -use crate::io::PgBufMutExt; - -#[derive(Debug)] -pub struct SaslInitialResponse { - pub response: String, - pub plus: bool, -} - -impl Serialize<'_, ()> for SaslInitialResponse { - fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { - buf.push(b'p'); - buf.write_length_prefixed(|buf| { - // name of the SASL authentication mechanism that the client selected - buf.write_str_nul(if self.plus { "SCRAM-SHA-256-PLUS" } else { "SCRAM-SHA-256" }); - - buf.extend(&(self.response.as_bytes().len() as i32).to_be_bytes()); - buf.extend(self.response.as_bytes()); - }); - - Ok(()) - } -} - -#[derive(Debug)] -pub struct SaslResponse<'a>(pub &'a str); - -impl Serialize<'_, ()> for SaslResponse<'_> { - fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { - buf.push(b'p'); - buf.write_length_prefixed(|buf| { - buf.extend(self.0.as_bytes()); - }); - - Ok(()) - } -} diff --git a/sqlx-postgres/src/protocol/startup.rs b/sqlx-postgres/src/protocol/startup.rs deleted file mode 100644 index 16433e71..00000000 --- a/sqlx-postgres/src/protocol/startup.rs +++ /dev/null @@ -1,64 +0,0 @@ -use sqlx_core::io::Serialize; -use sqlx_core::io::WriteExt; -use sqlx_core::Result; - -use crate::io::PgBufMutExt; - -// To begin a session, a frontend opens a connection to the server and sends a startup message. -// This message includes the names of the user and of the database the user wants to connect to; -// it also identifies the particular protocol version to be used. - -// Optionally, the startup message can include additional settings for run-time parameters. - -#[derive(Debug)] -pub struct Startup<'a> { - /// The database user name to connect as. Required; there is no default. - pub username: Option<&'a str>, - - /// The database to connect to. Defaults to the user name. - pub database: Option<&'a str>, - - /// Additional start-up params. - /// - pub params: &'a [(&'a str, &'a str)], -} - -impl Serialize<'_, ()> for Startup<'_> { - fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { - buf.reserve(120); - - buf.write_length_prefixed(|buf| { - // The protocol version number. The most significant 16 bits are the - // major version number (3 for the protocol described here). The least - // significant 16 bits are the minor version number (0 - // for the protocol described here) - buf.extend(&196_608_i32.to_be_bytes()); - - if let Some(username) = self.username { - // The database user name to connect as. - encode_startup_param(buf, "user", username); - } - - if let Some(database) = self.database { - // The database to connect to. Defaults to the user name. - encode_startup_param(buf, "database", database); - } - - for (name, value) in self.params { - encode_startup_param(buf, name, value); - } - - // A zero byte is required as a terminator - // after the last name/value pair. - buf.push(0); - }); - - Ok(()) - } -} - -#[inline] -fn encode_startup_param(buf: &mut Vec, name: &str, value: &str) { - buf.write_str_nul(name); - buf.write_str_nul(value); -} diff --git a/sqlx-postgres/src/protocol/terminate.rs b/sqlx-postgres/src/protocol/terminate.rs deleted file mode 100644 index 8b5c29e3..00000000 --- a/sqlx-postgres/src/protocol/terminate.rs +++ /dev/null @@ -1,14 +0,0 @@ -use sqlx_core::io::Serialize; -use sqlx_core::Result; - -#[derive(Debug)] -pub struct Terminate; - -impl Serialize<'_, ()> for Terminate { - fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { - buf.push(b'X'); - buf.extend(&4_u32.to_be_bytes()); - - Ok(()) - } -} diff --git a/sqlx-postgres/src/query_result.rs b/sqlx-postgres/src/query_result.rs new file mode 100644 index 00000000..90b7837a --- /dev/null +++ b/sqlx-postgres/src/query_result.rs @@ -0,0 +1,77 @@ +use std::convert::TryInto; +use std::fmt::{self, Debug, Formatter}; +use std::str::Utf8Error; + +use bytes::Bytes; +use bytestring::ByteString; +use memchr::memrchr; +use sqlx_core::QueryResult; + +// TODO: add unit tests for command tag parsing + +/// Represents the execution result of a command in Postgres. +/// +/// Returned from [`execute()`][sqlx_core::Executor::execute]. +/// +#[allow(clippy::module_name_repetitions)] +#[derive(Clone)] +pub struct PgQueryResult { + command: ByteString, + rows_affected: u64, +} + +impl PgQueryResult { + pub(crate) fn parse(mut command: Bytes) -> Result { + // look backwards for the first SPACE + let offset = memrchr(b' ', &command); + + let rows = if let Some(offset) = offset { + atoi::atoi(&command.split_off(offset).slice(1..)).unwrap_or(0) + } else { + 0 + }; + + Ok(Self { command: command.try_into()?, rows_affected: rows }) + } + + /// Returns the command tag. + /// + /// This is usually a single word that identifies which SQL command + /// was completed (e.g.,`INSERT`, `UPDATE`, or `MOVE`). + #[must_use] + pub fn command(&self) -> &str { + &self.command + } + + /// Returns the number of rows inserted, deleted, updated, retrieved, + /// changed, or copied by the SQL command. + #[must_use] + pub const fn rows_affected(&self) -> u64 { + self.rows_affected + } +} + +impl Debug for PgQueryResult { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("PgQueryResult") + .field("command", &self.command()) + .field("rows_affected", &self.rows_affected()) + .finish() + } +} + +impl Extend for PgQueryResult { + fn extend>(&mut self, iter: T) { + for res in iter { + self.rows_affected += res.rows_affected; + self.command = res.command; + } + } +} + +impl QueryResult for PgQueryResult { + #[inline] + fn rows_affected(&self) -> u64 { + self.rows_affected() + } +} diff --git a/sqlx-postgres/src/raw_value.rs b/sqlx-postgres/src/raw_value.rs new file mode 100644 index 00000000..b5dcf540 --- /dev/null +++ b/sqlx-postgres/src/raw_value.rs @@ -0,0 +1,55 @@ +use bytes::Bytes; +use sqlx_core::RawValue; + +use crate::{PgTypeInfo, Postgres}; + +/// The format of a raw SQL value for Postgres. +/// +/// Postgres returns values in [`Text`] or [`Binary`] format with a +/// configuration option in a prepared query. SQLx currently hard-codes that +/// option to [`Binary`]. +/// +/// For simple queries, postgres only can return values in [`Text`] format. +/// +#[derive(Debug, PartialEq, Copy, Clone)] +pub enum PgRawValueFormat { + Binary, + Text, +} + +/// The raw representation of a SQL value for Postgres. +// 'r: row +#[derive(Debug, Clone)] +#[allow(clippy::module_name_repetitions)] +pub struct PgRawValue<'r> { + value: Option<&'r Bytes>, + format: PgRawValueFormat, + type_info: PgTypeInfo, +} + +// 'r: row +impl<'r> PgRawValue<'r> { + /// Returns the type information for this value. + #[must_use] + pub const fn type_info(&self) -> &PgTypeInfo { + &self.type_info + } + + /// Returns the format of this value. + #[must_use] + pub const fn format(&self) -> PgRawValueFormat { + self.format + } +} + +impl<'r> RawValue<'r> for PgRawValue<'r> { + type Database = Postgres; + + fn is_null(&self) -> bool { + self.value.is_none() + } + + fn type_info(&self) -> &PgTypeInfo { + &self.type_info + } +} diff --git a/sqlx-postgres/src/row.rs b/sqlx-postgres/src/row.rs new file mode 100644 index 00000000..673fc0cc --- /dev/null +++ b/sqlx-postgres/src/row.rs @@ -0,0 +1,47 @@ +use sqlx_core::{ColumnIndex, Result, Row}; + +use crate::{PgColumn, PgRawValue, Postgres}; + +/// A single row from a result set generated from MySQL. +#[allow(clippy::module_name_repetitions)] +pub struct PgRow {} + +impl Row for PgRow { + type Database = Postgres; + + fn is_null(&self) -> bool { + // self.is_null() + todo!() + } + + fn len(&self) -> usize { + // self.len() + todo!() + } + + fn columns(&self) -> &[PgColumn] { + // self.columns() + todo!() + } + + fn try_column>(&self, index: I) -> Result<&PgColumn> { + // self.try_column(index) + todo!() + } + + fn column_name(&self, index: usize) -> Option<&str> { + // self.columns.get(index).map(PgColumn::name) + todo!() + } + + fn column_index(&self, name: &str) -> Option { + // self.columns.iter().position(|col| col.name() == name) + todo!() + } + + #[allow(clippy::needless_lifetimes)] + fn try_get_raw<'r, I: ColumnIndex>(&'r self, index: I) -> Result> { + // self.try_get_raw(index) + todo!() + } +} diff --git a/sqlx-postgres/src/stream.rs b/sqlx-postgres/src/stream.rs new file mode 100644 index 00000000..4bf3e5f3 --- /dev/null +++ b/sqlx-postgres/src/stream.rs @@ -0,0 +1,161 @@ +use std::convert::TryInto; +use std::fmt::Debug; +use std::ops::{Deref, DerefMut}; + +use bytes::Buf; +use sqlx_core::io::{BufStream, Serialize}; +use sqlx_core::net::Stream as NetStream; +use sqlx_core::{Result, Runtime}; + +use crate::protocol::{BackendMessage, BackendMessageType}; + +/// Reads and writes messages to and from the PostgreSQL database server. +/// +/// The logic for serializing data structures into the messages is found +/// mostly in `protocol/`. +/// +/// The first byte of a message identifies the message type, and the next +/// four bytes specify the length of the rest of the message (this length +/// count includes itself, but not the message-type byte). The remaining +/// contents of the message are determined by the message type. For +/// historical reasons, the very first message sent by the client ( +/// the startup message) has no initial message-type byte. +/// +/// +/// +#[allow(clippy::module_name_repetitions)] +pub(crate) struct PgStream { + stream: BufStream>, +} + +impl PgStream { + pub(crate) fn new(stream: NetStream) -> Self { + Self { stream: BufStream::with_capacity(stream, 4096, 1024) } + } + + // all communication is through a stream of messages + pub(crate) fn write_message<'ser, T>(&'ser mut self, message: &T) -> Result<()> + where + T: Serialize<'ser> + Debug, + { + Ok(()) + } + + // reads and consumes a message from the stream buffer + // assumes there is a message on the stream + fn read_message(&mut self, size: usize) -> Result> { + // the first byte is the message type + let ty = self.stream.get(0, 1).get_u8(); + let ty: BackendMessageType = ty.try_into()?; + + // the next 4 bytes was the length of the message + self.stream.consume(5); + + // and now take the message contents + let contents = self.stream.take(size); + + if contents.len() != size { + // TODO: return a database error + // BUG: something is very wrong somewhere if this branch is executed + // either in the SQLx Postgres driver or in the Postgres server + unimplemented!( + "Received {} bytes for packet but expecting {} bytes", + contents.len(), + size + ); + } + + match ty { + BackendMessageType::ErrorResponse => { + // TODO: return a proper error + unimplemented!("error response"); + } + + BackendMessageType::NotificationResponse => { + // TODO: handle these similar to master + Ok(None) + } + + BackendMessageType::NoticeResponse => { + // TODO: log the incoming message + Ok(None) + } + + BackendMessageType::ParameterStatus => { + // TODO: pull out and remember server version + Ok(None) + } + + _ => Ok(Some(BackendMessage { contents, r#type: ty })), + } + } +} + +macro_rules! impl_read_message { + ($(@$blocking:ident)? $self:ident) => {{ + Ok(loop { + // reads at least 5 bytes from the IO stream into the read buffer + impl_read_message!($(@$blocking)? @stream $self, 0, 5); + + // bytes 1..4 will be the length of the message + let size = ($self.stream.get(1, 4).get_u32() - 4) as usize; + + // read bytes _after_ the header + impl_read_message!($(@$blocking)? @stream $self, 4, size); + + if let Some(message) = $self.read_message(size)? { + break message; + } + }) + }}; + + (@blocking @stream $self:ident, $offset:expr, $n:expr) => { + $self.stream.read($offset, $n)?; + }; + + (@stream $self:ident, $offset:expr, $n:expr) => { + $self.stream.read_async($offset, $n).await?; + }; +} + +impl PgStream { + #[cfg(feature = "async")] + pub(crate) async fn read_message_async(&mut self) -> Result + where + Rt: sqlx_core::Async, + { + impl_read_message!(self) + } + + #[cfg(feature = "blocking")] + pub(crate) fn read_message_blocking(&mut self) -> Result + where + Rt: sqlx_core::blocking::Runtime, + { + impl_read_message!(@blocking self) + } +} + +impl Deref for PgStream { + type Target = BufStream>; + + fn deref(&self) -> &Self::Target { + &self.stream + } +} + +impl DerefMut for PgStream { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.stream + } +} + +macro_rules! read_message { + (@blocking $stream:expr) => { + $stream.read_message_blocking()? + }; + + ($stream:expr) => { + $stream.read_message_async().await? + }; +} diff --git a/sqlx-postgres/src/type_id.rs b/sqlx-postgres/src/type_id.rs new file mode 100644 index 00000000..f10ef745 --- /dev/null +++ b/sqlx-postgres/src/type_id.rs @@ -0,0 +1,120 @@ +/// A unique identifier for a Postgres data type. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +#[cfg_attr( + any(feature = "offline", feature = "serde"), + derive(serde::Serialize, serde::Deserialize) +)] +#[allow(clippy::module_name_repetitions)] +pub enum PgTypeId { + Oid(u32), + Name(&'static str), +} + +// Data Types +// https://www.postgresql.org/docs/current/datatype.html + +impl PgTypeId { + // Boolean + // https://www.postgresql.org/docs/current/datatype-boolean.html + + /// The SQL standard `boolean` type. + /// + /// Maps to `bool`. + /// + pub const BOOLEAN: Self = Self::Oid(16); + + // Integers + // https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-INT + + /// A 2-byte integer. + /// + /// Compatible with any primitive integer type. + /// + /// Maps to `i16`. + /// + #[doc(alias = "INT2")] + #[doc(alias = "SMALLSERIAL")] + pub const SMALLINT: Self = Self::Oid(21); + + /// A 4-byte integer. + /// + /// Compatible with any primitive integer type. + /// + /// Maps to `i32`. + /// + #[doc(alias = "INT4")] + #[doc(alias = "SERIAL")] + pub const INTEGER: Self = Self::Oid(23); + + /// An 8-byte integer. + /// + /// Compatible with any primitive integer type. + /// + /// Maps to `i64`. + /// + #[doc(alias = "INT8")] + #[doc(alias = "BIGSERIAL")] + pub const BIGINT: Self = Self::Oid(20); + + // Arbitrary Precision Numbers + // https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-NUMERIC-DECIMAL + + /// An exact numeric type with a user-specified precision. + /// + /// Compatible with [`bigdecimal::BigDecimal`], [`rust_decimal::Decimal`], [`num_int::BigInt`], and any + /// primitive integer type. Truncation or loss-of-precision is considered an error + /// when decoding into the selected Rust integer type. + /// + /// With a scale of `0` (e.g, `NUMERIC(17, 0)`), maps to `num_int::BigInt`; otherwise, + /// maps to [`bigdecimal::BigDecimal`] or [`rust_decimal::Decimal`] (depending on + /// enabled crate features). + /// + #[doc(alias = "DECIMAL")] + pub const NUMERIC: Self = Self::Oid(1700); + + // Floating-Point + // https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-FLOAT + + /// A 4-byte floating-point numeric type. + /// + /// Compatible with `f32` or `f64`. + /// + /// Maps to `f32`. + /// + #[doc(alias = "FLOAT4")] + pub const REAL: Self = Self::Oid(700); + + /// An 8-byte floating-point numeric type. + /// + /// Compatible with `f32` or `f64`. + /// + /// Maps to `f64`. + /// + #[doc(alias = "FLOAT8")] + pub const DOUBLE: Self = Self::Oid(701); + + /// The `UNKNOWN` Postgres type. Returned for expressions that do not + /// have a type (e.g., `SELECT $1` with no parameter type hint + /// or `SELECT NULL`). + pub const UNKNOWN: Self = Self::Oid(705); +} + +impl PgTypeId { + #[must_use] + pub(crate) const fn name(self) -> &'static str { + match self { + Self::BOOLEAN => "BOOLEAN", + + Self::SMALLINT => "SMALLINT", + Self::INTEGER => "INTEGER", + Self::BIGINT => "BIGINT", + + Self::NUMERIC => "NUMERIC", + + Self::REAL => "REAL", + Self::DOUBLE => "DOUBLE", + + _ => "UNKNOWN", + } + } +} diff --git a/sqlx-postgres/src/type_info.rs b/sqlx-postgres/src/type_info.rs new file mode 100644 index 00000000..dc528b28 --- /dev/null +++ b/sqlx-postgres/src/type_info.rs @@ -0,0 +1,42 @@ +use sqlx_core::TypeInfo; + +use crate::{PgTypeId, Postgres}; + +/// Provides information about a Postgres type. +#[derive(Debug, Clone, Copy)] +#[cfg_attr( + any(feature = "offline", feature = "serde"), + derive(serde::Serialize, serde::Deserialize) +)] +#[allow(clippy::module_name_repetitions)] +pub struct PgTypeInfo(pub(crate) PgTypeId); + +impl PgTypeInfo { + /// Returns the unique identifier for this Postgres type. + #[must_use] + pub const fn id(&self) -> PgTypeId { + self.0 + } + + /// Returns the name for this Postgres type. + #[must_use] + pub const fn name(&self) -> &'static str { + self.0.name() + } +} + +impl TypeInfo for PgTypeInfo { + type Database = Postgres; + + fn id(&self) -> PgTypeId { + self.id() + } + + fn is_unknown(&self) -> bool { + matches!(self.0, PgTypeId::UNKNOWN) + } + + fn name(&self) -> &'static str { + self.name() + } +} diff --git a/sqlx-postgres/src/types.rs b/sqlx-postgres/src/types.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/sqlx-postgres/src/types.rs @@ -0,0 +1 @@ +