From d5053d1b1df53ecae0394ff5ddd5302b7cdf54ca Mon Sep 17 00:00:00 2001 From: Daniel Akhterov Date: Thu, 21 Jan 2021 23:37:18 -0800 Subject: [PATCH] feat: begin work on postgres --- Cargo.lock | 36 +++- Cargo.toml | 1 + sqlx-core/Cargo.toml | 2 +- sqlx-postgres/Cargo.toml | 51 ++++++ sqlx-postgres/src/connection.rs | 121 +++++++++++++ sqlx-postgres/src/connection/close.rs | 32 ++++ sqlx-postgres/src/connection/connect.rs | 179 +++++++++++++++++++ sqlx-postgres/src/connection/ping.rs | 22 +++ sqlx-postgres/src/connection/stream.rs | 186 ++++++++++++++++++++ sqlx-postgres/src/database.rs | 15 ++ sqlx-postgres/src/error.rs | 23 +++ sqlx-postgres/src/io.rs | 3 + sqlx-postgres/src/io/write.rs | 52 ++++++ sqlx-postgres/src/lib.rs | 34 ++++ sqlx-postgres/src/options.rs | 103 +++++++++++ sqlx-postgres/src/options/builder.rs | 82 +++++++++ sqlx-postgres/src/options/default.rs | 38 +++++ sqlx-postgres/src/options/getters.rs | 55 ++++++ sqlx-postgres/src/options/parse.rs | 183 ++++++++++++++++++++ sqlx-postgres/src/protocol.rs | 91 ++++++++++ sqlx-postgres/src/protocol/close.rs | 36 ++++ sqlx-postgres/src/protocol/flush.rs | 18 ++ sqlx-postgres/src/protocol/notification.rs | 28 +++ sqlx-postgres/src/protocol/response.rs | 190 +++++++++++++++++++++ sqlx-postgres/src/protocol/startup.rs | 64 +++++++ sqlx-postgres/src/protocol/terminate.rs | 14 ++ sqlx/Cargo.toml | 6 + sqlx/src/lib.rs | 3 + sqlx/src/postgres/blocking.rs | 2 + sqlx/src/postgres/blocking/connection.rs | 59 +++++++ sqlx/src/postgres/blocking/options.rs | 28 +++ sqlx/src/postgres/connection.rs | 93 ++++++++++ sqlx/src/postgres/database.rs | 15 ++ sqlx/src/postgres/options.rs | 95 +++++++++++ sqlx/src/postgres/options/builder.rs | 74 ++++++++ sqlx/src/postgres/options/getters.rs | 62 +++++++ x.py | 3 + 37 files changed, 2096 insertions(+), 3 deletions(-) create mode 100644 sqlx-postgres/Cargo.toml create mode 100644 sqlx-postgres/src/connection.rs create mode 100644 sqlx-postgres/src/connection/close.rs create mode 100644 sqlx-postgres/src/connection/connect.rs create mode 100644 sqlx-postgres/src/connection/ping.rs create mode 100644 sqlx-postgres/src/connection/stream.rs create mode 100644 sqlx-postgres/src/database.rs create mode 100644 sqlx-postgres/src/error.rs create mode 100644 sqlx-postgres/src/io.rs create mode 100644 sqlx-postgres/src/io/write.rs create mode 100644 sqlx-postgres/src/lib.rs create mode 100644 sqlx-postgres/src/options.rs create mode 100644 sqlx-postgres/src/options/builder.rs create mode 100644 sqlx-postgres/src/options/default.rs create mode 100644 sqlx-postgres/src/options/getters.rs create mode 100644 sqlx-postgres/src/options/parse.rs create mode 100644 sqlx-postgres/src/protocol.rs create mode 100644 sqlx-postgres/src/protocol/close.rs create mode 100644 sqlx-postgres/src/protocol/flush.rs create mode 100644 sqlx-postgres/src/protocol/notification.rs create mode 100644 sqlx-postgres/src/protocol/response.rs create mode 100644 sqlx-postgres/src/protocol/startup.rs create mode 100644 sqlx-postgres/src/protocol/terminate.rs create mode 100644 sqlx/src/postgres/blocking.rs create mode 100644 sqlx/src/postgres/blocking/connection.rs create mode 100644 sqlx/src/postgres/blocking/options.rs create mode 100644 sqlx/src/postgres/connection.rs create mode 100644 sqlx/src/postgres/database.rs create mode 100644 sqlx/src/postgres/options.rs create mode 100644 sqlx/src/postgres/options/builder.rs create mode 100644 sqlx/src/postgres/options/getters.rs diff --git a/Cargo.lock b/Cargo.lock index 3c7cdf4c..3450b9ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -39,8 +39,8 @@ dependencies = [ [[package]] name = "async-compat" -version = "0.1.5" -source = "git+https://github.com/taiki-e/async-compat?branch=tokio1#8d87a0917ebe27e4e3caa944d2991d26b1050fb0" +version = "0.2.0" +source = "git+https://github.com/smol-rs/async-compat?branch=master#e1c197b19788fb8f449c72095bf7a9e72e3b95b0" dependencies = [ "futures-core", "futures-io", @@ -562,6 +562,12 @@ dependencies = [ "cfg-if 1.0.0", ] +[[package]] +name = "itoa" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736" + [[package]] name = "js-sys" version = "0.3.46" @@ -1051,6 +1057,7 @@ dependencies = [ "futures-util", "sqlx-core", "sqlx-mysql", + "sqlx-postgres", ] [[package]] @@ -1098,6 +1105,31 @@ dependencies = [ "url", ] +[[package]] +name = "sqlx-postgres" +version = "0.6.0-pre" +dependencies = [ + "anyhow", + "base64", + "bitflags", + "bytes", + "bytestring", + "either", + "futures-executor", + "futures-io", + "futures-util", + "itoa", + "log", + "memchr", + "percent-encoding", + "rand", + "rsa", + "sha-1", + "sha2", + "sqlx-core", + "url", +] + [[package]] name = "subtle" version = "2.4.0" diff --git a/Cargo.toml b/Cargo.toml index 60632f9f..b96d900b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,5 +4,6 @@ default-members = ["sqlx"] members = [ "sqlx-core", "sqlx-mysql", + "sqlx-postgres", "sqlx", ] diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 2e7ec8dc..22fe7d8d 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -43,7 +43,7 @@ actix-rt = { version = "2.0.0-beta.2", optional = true } _async-std = { version = "1.8", optional = true, package = "async-std" } futures-util = { version = "0.3", optional = true, features = ["io"] } _tokio = { version = "1.0", optional = true, package = "tokio", features = ["net", "io-util"] } -async-compat = { version = "*", git = "https://github.com/taiki-e/async-compat", branch = "tokio1", optional = true } +async-compat = { version = "*", git = "https://github.com/smol-rs/async-compat", branch = "master", optional = true } futures-io = { version = "0.3", optional = true } futures-core = { version = "0.3", optional = true } bytes = "1.0" diff --git a/sqlx-postgres/Cargo.toml b/sqlx-postgres/Cargo.toml new file mode 100644 index 00000000..40abec0e --- /dev/null +++ b/sqlx-postgres/Cargo.toml @@ -0,0 +1,51 @@ +[package] +name = "sqlx-postgres" +version = "0.6.0-pre" +repository = "https://github.com/launchbadge/sqlx" +description = "MySQL database driver for SQLx, the Rust SQL Toolkit." +license = "MIT OR Apache-2.0" +edition = "2018" +keywords = ["postgres", "sqlx", "database"] +categories = ["database", "asynchronous"] +authors = [ + "LaunchBadge " +] + +[package.metadata.docs.rs] +# > RUSTDOCFLAGS="--cfg doc_cfg" cargo +nightly doc --all-features --no-deps --open +all-features = true +rustdoc-args = ["--cfg", "doc_cfg"] + +[features] +default = [] + +# blocking (std) runtime +blocking = ["sqlx-core/blocking"] + +# async runtime +# not meant to be used directly +async = ["futures-util", "sqlx-core/async", "futures-io"] + +[dependencies] +sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core" } +futures-util = { version = "0.3.8", optional = true } +either = "1.6.1" +log = "0.4.11" +bytestring = "1.0.0" +url = "2.2.0" +percent-encoding = "2.1.0" +futures-io = { version = "0.3", optional = true } +bytes = "1.0" +memchr = "2.3" +bitflags = "1.2" +sha-1 = "0.9.2" +sha2 = "0.9.2" +rsa = "0.3.0" +base64 = "0.13.0" +rand = "0.7" +itoa = "0.4.7" + +[dev-dependencies] +sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core", features = ["_mock"] } +futures-executor = "0.3.8" +anyhow = "1.0.37" diff --git a/sqlx-postgres/src/connection.rs b/sqlx-postgres/src/connection.rs new file mode 100644 index 00000000..55b86835 --- /dev/null +++ b/sqlx-postgres/src/connection.rs @@ -0,0 +1,121 @@ +use std::fmt::{self, Debug, Formatter}; + +use sqlx_core::io::BufStream; +use sqlx_core::net::Stream as NetStream; +use sqlx_core::{Close, Connect, Connection, Runtime}; + +use crate::{Postgres, PostgresConnectOptions}; + +mod close; +mod connect; +mod ping; +mod stream; + +/// A single connection (also known as a session) to a PostgreSQL database server. +#[allow(clippy::module_name_repetitions)] +pub struct PostgresConnection +where + Rt: Runtime, +{ + stream: BufStream>, + + // process id of this backend + // used to send cancel requests + #[allow(dead_code)] + process_id: u32, + + // secret key of this backend + // used to send cancel requests + #[allow(dead_code)] + secret_key: u32, +} + +impl PostgresConnection +where + Rt: Runtime, +{ + pub(crate) fn new(stream: NetStream) -> Self { + Self { stream: BufStream::with_capacity(stream, 4096, 1024), process_id: 0, secret_key: 0 } + } +} + +impl Debug for PostgresConnection +where + Rt: Runtime, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("PostgresConnection").finish() + } +} + +impl Connection for PostgresConnection +where + Rt: Runtime, +{ + type Database = Postgres; + + #[cfg(feature = "async")] + fn ping(&mut self) -> futures_util::future::BoxFuture<'_, sqlx_core::Result<()>> + where + Rt: sqlx_core::Async, + { + Box::pin(self.ping_async()) + } +} + +impl Connect for PostgresConnection { + type Options = PostgresConnectOptions; + + #[cfg(feature = "async")] + fn connect(url: &str) -> futures_util::future::BoxFuture<'_, sqlx_core::Result> + where + Self: Sized, + Rt: sqlx_core::Async, + { + use sqlx_core::ConnectOptions; + + let options = url.parse::(); + Box::pin(async move { options?.connect().await }) + } +} + +impl Close for PostgresConnection { + #[cfg(feature = "async")] + fn close(self) -> futures_util::future::BoxFuture<'static, sqlx_core::Result<()>> + where + Rt: sqlx_core::Async, + { + Box::pin(self.close_async()) + } +} + +#[cfg(feature = "blocking")] +mod blocking { + use sqlx_core::blocking::{Close, Connect, Connection, Runtime}; + + use super::{PostgresConnectOptions, PostgresConnection}; + + impl Connection for PostgresConnection { + #[inline] + fn ping(&mut self) -> sqlx_core::Result<()> { + self.ping() + } + } + + impl Connect for PostgresConnection { + #[inline] + fn connect(url: &str) -> sqlx_core::Result + where + Self: Sized, + { + Self::connect(&url.parse::>()?) + } + } + + impl Close for PostgresConnection { + #[inline] + fn close(self) -> sqlx_core::Result<()> { + self.close() + } + } +} diff --git a/sqlx-postgres/src/connection/close.rs b/sqlx-postgres/src/connection/close.rs new file mode 100644 index 00000000..7bbefeb2 --- /dev/null +++ b/sqlx-postgres/src/connection/close.rs @@ -0,0 +1,32 @@ +use sqlx_core::{io::Stream, Result, Runtime}; + +use crate::protocol::Terminate; + +impl super::PostgresConnection +where + Rt: Runtime, +{ + #[cfg(feature = "async")] + pub(crate) async fn close_async(mut self) -> Result<()> + where + Rt: sqlx_core::Async, + { + self.write_packet(&Terminate)?; + self.stream.flush_async().await?; + self.stream.shutdown_async().await?; + + Ok(()) + } + + #[cfg(feature = "blocking")] + pub(crate) fn close(mut self) -> Result<()> + where + Rt: sqlx_core::blocking::Runtime, + { + self.write_packet(&Terminate)?; + self.stream.flush()?; + self.stream.shutdown()?; + + Ok(()) + } +} diff --git a/sqlx-postgres/src/connection/connect.rs b/sqlx-postgres/src/connection/connect.rs new file mode 100644 index 00000000..33fe7396 --- /dev/null +++ b/sqlx-postgres/src/connection/connect.rs @@ -0,0 +1,179 @@ +//! Implements the connection phase. +//! +//! The connection phase (establish) performs these tasks: +//! +//! - exchange the capabilities of client and server +//! - setup SSL communication channel if requested +//! - authenticate the client against the server +//! +//! The server may immediately send an ERR packet and finish the handshake +//! or send a `Handshake`. +//! +//! https://dev.postgres.com/doc/internals/en/connection-phase.html +//! +use sqlx_core::net::Stream as NetStream; +use sqlx_core::Error; +use sqlx_core::Result; + +use crate::protocol::{Message, MessageType, Startup}; +use crate::{PostgresConnectOptions, PostgresConnection}; + +macro_rules! connect { + (@blocking @tcp $options:ident) => { + NetStream::connect($options.address.as_ref())?; + }; + + (@tcp $options:ident) => { + NetStream::connect_async($options.address.as_ref()).await?; + }; + + (@blocking @packet $self:ident) => { + $self.read_message()?; + }; + + (@packet $self:ident) => { + $self.read_message_async().await?; + }; + + ($(@$blocking:ident)? $options:ident) => {{ + // open a network stream to the database server + let stream = connect!($(@$blocking)? @tcp $options); + + // construct a around the network stream + // wraps the stream in a to buffer read and write + let mut self_ = Self::new(stream); + + // To begin a session, a frontend opens a connection to the server + // and sends a startup message. + + let mut params = vec![ + // Sets the display format for date and time values, + // as well as the rules for interpreting ambiguous date input values. + ("DateStyle", "ISO, MDY"), + // Sets the client-side encoding (character set). + // + ("client_encoding", "UTF8"), + // Sets the time zone for displaying and interpreting time stamps. + ("TimeZone", "UTC"), + // Adjust postgres to return precise values for floats + // NOTE: This is default in postgres 12+ + ("extra_float_digits", "3"), + ]; + + // if let Some(ref application_name) = $options.get_application_name() { + // params.push(("application_name", application_name)); + // } + + self_.write_packet(&Startup { + username: $options.get_username(), + database: $options.get_database(), + params: ¶ms, + })?; + + // The server then uses this information and the contents of + // its configuration files (such as pg_hba.conf) to determine whether the connection is + // provisionally acceptable, and what additional + // authentication is required (if any). + + let mut process_id = 0; + let mut secret_key = 0; + let transaction_status; + + loop { + let message: Message = connect!($(@$blocking)? @packet self_); + match message.r#type { + MessageType::Authentication => match message.decode()? { + // Authentication::Ok => { + // the authentication exchange is successfully completed + // do nothing; no more information is required to continue + // } + + // Authentication::CleartextPassword => { + // // The frontend must now send a [PasswordMessage] containing the + // // password in clear-text form. + + // stream + // .send(Password::Cleartext( + // options.password.as_deref().unwrap_or_default(), + // )) + // .await?; + // } + + // Authentication::Md5Password(body) => { + // // The frontend must now send a [PasswordMessage] containing the + // // password (with user name) encrypted via MD5, then encrypted again + // // using the 4-byte random salt specified in the + // // [AuthenticationMD5Password] message. + + // stream + // .send(Password::Md5 { + // username: &options.username, + // password: options.password.as_deref().unwrap_or_default(), + // salt: body.salt, + // }) + // .await?; + // } + + // Authentication::Sasl(body) => { + // sasl::authenticate(&mut stream, options, body).await?; + // } + + method => { + return Err(Error::configuration_msg(format!( + "unsupported authentication method: {:?}", + method + ))); + } + }, + + // MessageFormat::BackendKeyData => { + // // provides secret-key data that the frontend must save if it wants to be + // // able to issue cancel requests later + + // let data: BackendKeyData = message.decode()?; + + // process_id = data.process_id; + // secret_key = data.secret_key; + // } + + // MessageFormat::ReadyForQuery => { + // // start-up is completed. The frontend can now issue commands + // transaction_status = + // ReadyForQuery::decode(message.contents)?.transaction_status; + + // break; + // } + + _ => { + return Err(Error::configuration_msg(format!( + "establish: unexpected message: {:?}", + message.format + ))) + } + } + } + + Ok(self_) + }}; +} + +impl PostgresConnection +where + Rt: sqlx_core::Runtime, +{ + #[cfg(feature = "async")] + pub(crate) async fn connect_async(options: &PostgresConnectOptions) -> Result + where + Rt: sqlx_core::Async, + { + connect!(options) + } + + #[cfg(feature = "blocking")] + pub(crate) fn connect(options: &PostgresConnectOptions) -> Result + where + Rt: sqlx_core::blocking::Runtime, + { + connect!(@blocking options) + } +} diff --git a/sqlx-postgres/src/connection/ping.rs b/sqlx-postgres/src/connection/ping.rs new file mode 100644 index 00000000..9e3c5ee0 --- /dev/null +++ b/sqlx-postgres/src/connection/ping.rs @@ -0,0 +1,22 @@ +use sqlx_core::{Result, Runtime}; + +impl super::PostgresConnection +where + Rt: Runtime, +{ + #[cfg(feature = "async")] + pub(crate) async fn ping_async(&mut self) -> Result<()> + where + Rt: sqlx_core::Async, + { + todo!(); + } + + #[cfg(feature = "blocking")] + pub(crate) fn ping(&mut self) -> Result<()> + where + Rt: sqlx_core::blocking::Runtime, + { + todo!(); + } +} diff --git a/sqlx-postgres/src/connection/stream.rs b/sqlx-postgres/src/connection/stream.rs new file mode 100644 index 00000000..9529fede --- /dev/null +++ b/sqlx-postgres/src/connection/stream.rs @@ -0,0 +1,186 @@ +//! Reads and writes packets to and from the PostgreSQL database server. +//! +//! The logic for serializing data structures into the packets is found +//! mostly in `protocol/`. +//! +//! Packets in PostgreSQL are prefixed by 4 bytes. +//! 3 for length (in LE) and a sequence id. +//! +//! Packets may only be as large as the communicated size in the initial +//! `HandshakeResponse`. By default, SQLx configures its chunk size to 16M. Sending +//! a larger payload is simply sending completely "full" packets, one after the +//! other, with an increasing sequence id. +//! +//! In other words, when we sent data, we: +//! +//! - Split the data into "packets" of size `2 ** 24 - 1` bytes. +//! +//! - Prepend each packet with a **packet header**, consisting of the length of that packet, +//! and the sequence number. +//! +//! https://dev.postgres.com/doc/internals/en/postgres-packet.html +//! +use std::convert::TryFrom; +use std::fmt::Debug; + +use bytes::{Buf, Bytes}; +use log::Level; +use sqlx_core::io::{Deserialize, Serialize}; +use sqlx_core::{Result, Runtime}; + +use crate::protocol::{Message, MessageType, Notice, PgSeverity}; +use crate::PostgresConnection; + +impl PostgresConnection +where + Rt: Runtime, +{ + pub(super) fn write_packet<'ser, T>(&'ser mut self, packet: &T) -> Result<()> + where + T: Serialize<'ser, ()> + Debug, + { + log::trace!("write > {:?}", packet); + + let buf = self.stream.buffer(); + packet.serialize_with(buf, ())?; + + Ok(()) + } + + pub(crate) fn recv_message(&mut self) -> Result { + // all packets in postgres start with a 5-byte header + // this header contains the message type and the total length of the message + let mut header: Bytes = self.stream.take(5); + + let r#type = MessageType::try_from(header.get_u8())?; + let size = (header.get_u32() - 4) as usize; + + let contents = self.stream.take(size); + + Ok(Message { r#type, contents }) + } + + fn recv_packet<'de, T>(&'de mut self, len: usize) -> Result + where + T: Deserialize<'de, ()> + Debug, + { + loop { + let message = self.recv_message()?; + + match message.r#type { + MessageType::ErrorResponse => { + // An error returned from the database server. + // return Err(PgDatabaseError(message.decode()?).into()); + panic!("got error response"); + } + + MessageType::NotificationResponse => { + // if let Some(buffer) = &mut self.notifications { + // let notification: Notification = message.decode()?; + // let _ = self.write_packet(notification); + + // continue; + // } + continue; + } + + MessageType::ParameterStatus => { + // informs the frontend about the current (initial) + // setting of backend parameters + + // we currently have no use for that data so we promptly ignore this message + continue; + } + + MessageType::NoticeResponse => { + // do we need this to be more configurable? + // if you are reading this comment and think so, open an issue + + let notice: Notice = message.decode()?; + + let lvl = match notice.severity() { + PgSeverity::Fatal | PgSeverity::Panic | PgSeverity::Error => Level::Error, + PgSeverity::Warning => Level::Warn, + PgSeverity::Notice => Level::Info, + PgSeverity::Debug => Level::Debug, + PgSeverity::Info => Level::Trace, + PgSeverity::Log => Level::Trace, + }; + + if lvl <= log::STATIC_MAX_LEVEL && lvl <= log::max_level() { + log::logger().log( + &log::Record::builder() + .args(format_args!("{}", notice.message())) + .level(lvl) + .module_path_static(Some("sqlx::postgres::notice")) + .file_static(Some(file!())) + .line(Some(line!())) + .build(), + ); + } + + continue; + } + + _ => {} + } + + return T::deserialize_with(message.contents, ()); + } + } +} + +macro_rules! read_packet { + ($(@$blocking:ident)? $self:ident) => {{ + // reads at least 4 bytes from the IO stream into the read buffer + read_packet!($(@$blocking)? @stream $self, 0, 4); + + // the first 3 bytes will be the payload length of the packet (in LE) + // ALLOW: the max this len will be is 16M + #[allow(clippy::cast_possible_truncation)] + let payload_len: usize = $self.stream.get(0, 3).get_uint_le(3) as usize; + + // read bytes _after_ the 4 byte packet header + // note that we have not yet told the stream we are done with any of + // these bytes yet. if this next read invocation were to never return (eg., the + // outer future was dropped), then the next time read_packet_async was called + // it will re-read the parsed-above packet header. Note that we have NOT + // mutated `self` _yet_. This is important. + read_packet!($(@$blocking)? @stream $self, 4, payload_len); + + $self.recv_packet(payload_len) + }}; + + (@blocking @stream $self:ident, $offset:expr, $n:expr) => { + $self.stream.read($offset, $n)?; + }; + + (@stream $self:ident, $offset:expr, $n:expr) => { + $self.stream.read_async($offset, $n).await?; + }; +} + +impl PostgresConnection +where + Rt: Runtime, +{ + #[cfg(feature = "async")] + + pub(super) async fn read_packet_async<'de, T>(&'de mut self) -> Result + where + T: Deserialize<'de, ()> + Debug, + Rt: sqlx_core::Async, + { + read_packet!(self) + } + + #[cfg(feature = "blocking")] + + pub(super) fn read_packet<'de, T>(&'de mut self) -> Result + where + T: Deserialize<'de, ()> + Debug, + Rt: sqlx_core::blocking::Runtime, + { + read_packet!(@blocking self) + } +} diff --git a/sqlx-postgres/src/database.rs b/sqlx-postgres/src/database.rs new file mode 100644 index 00000000..b3c6508e --- /dev/null +++ b/sqlx-postgres/src/database.rs @@ -0,0 +1,15 @@ +use sqlx_core::{Database, HasOutput, Runtime}; + +#[derive(Debug)] +pub struct Postgres; + +impl Database for Postgres +where + Rt: Runtime, +{ + type Connection = super::PostgresConnection; +} + +impl<'x> HasOutput<'x> for Postgres { + type Output = &'x mut Vec; +} diff --git a/sqlx-postgres/src/error.rs b/sqlx-postgres/src/error.rs new file mode 100644 index 00000000..94395a14 --- /dev/null +++ b/sqlx-postgres/src/error.rs @@ -0,0 +1,23 @@ +use std::error::Error as StdError; +use std::fmt::{self, Display, Formatter}; + +use sqlx_core::DatabaseError; + +/// An error returned from the PostgreSQL database server. +#[allow(clippy::module_name_repetitions)] +#[derive(Debug)] +pub struct PostgresDatabaseError(); + +impl DatabaseError for PostgresDatabaseError { + fn message(&self) -> &str { + todo!() + } +} + +impl Display for PostgresDatabaseError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "TODO") + } +} + +impl StdError for PostgresDatabaseError {} diff --git a/sqlx-postgres/src/io.rs b/sqlx-postgres/src/io.rs new file mode 100644 index 00000000..51f04761 --- /dev/null +++ b/sqlx-postgres/src/io.rs @@ -0,0 +1,3 @@ +mod write; + +pub(crate) use write::PgBufMutExt; diff --git a/sqlx-postgres/src/io/write.rs b/sqlx-postgres/src/io/write.rs new file mode 100644 index 00000000..4b13b7c9 --- /dev/null +++ b/sqlx-postgres/src/io/write.rs @@ -0,0 +1,52 @@ +pub trait PgBufMutExt { + fn write_length_prefixed(&mut self, f: F) + where + F: FnOnce(&mut Vec); + + fn write_statement_name(&mut self, id: u32); + + fn write_portal_name(&mut self, id: Option); +} + +impl PgBufMutExt for Vec { + // writes a length-prefixed message, this is used when encoding nearly all messages as postgres + // wants us to send the length of the often-variable-sized messages up front + fn write_length_prefixed(&mut self, f: F) + where + F: FnOnce(&mut Vec), + { + // reserve space to write the prefixed length + let offset = self.len(); + self.extend(&[0; 4]); + + // write the main body of the message + f(self); + + // now calculate the size of what we wrote and set the length value + let size = (self.len() - offset) as i32; + self[offset..(offset + 4)].copy_from_slice(&size.to_be_bytes()); + } + + // writes a statement name by ID + #[inline] + fn write_statement_name(&mut self, id: u32) { + // N.B. if you change this don't forget to update it in ../describe.rs + self.extend(b"sqlx_s_"); + + itoa::write(&mut *self, id).unwrap(); + + self.push(0); + } + + // writes a portal name by ID + #[inline] + fn write_portal_name(&mut self, id: Option) { + if let Some(id) = id { + self.extend(b"sqlx_p_"); + + itoa::write(&mut *self, id).unwrap(); + } + + self.push(0); + } +} diff --git a/sqlx-postgres/src/lib.rs b/sqlx-postgres/src/lib.rs new file mode 100644 index 00000000..e8ece223 --- /dev/null +++ b/sqlx-postgres/src/lib.rs @@ -0,0 +1,34 @@ +//! [PostgreSQL] database driver. +//! +//! [PostgreSQL]: https://www.postgres.com/ +//! +#![cfg_attr(doc_cfg, feature(doc_cfg))] +#![cfg_attr(not(any(feature = "async", feature = "blocking")), allow(unused))] +#![deny(unsafe_code)] +#![warn(rust_2018_idioms)] +#![warn(future_incompatible)] +#![warn(clippy::pedantic)] +#![warn(clippy::multiple_crate_versions)] +#![warn(clippy::cognitive_complexity)] +#![warn(clippy::future_not_send)] +#![warn(clippy::missing_const_for_fn)] +#![warn(clippy::needless_borrow)] +#![warn(clippy::string_lit_as_bytes)] +#![warn(clippy::use_self)] +#![warn(clippy::useless_let_if_seq)] +#![allow(clippy::doc_markdown)] + +mod connection; +mod database; +mod error; +mod io; +mod options; +mod protocol; + +#[cfg(test)] +mod mock; + +pub use connection::PostgresConnection; +pub use database::Postgres; +pub use error::PostgresDatabaseError; +pub use options::PostgresConnectOptions; diff --git a/sqlx-postgres/src/options.rs b/sqlx-postgres/src/options.rs new file mode 100644 index 00000000..06ee91bd --- /dev/null +++ b/sqlx-postgres/src/options.rs @@ -0,0 +1,103 @@ +use std::fmt::{self, Debug, Formatter}; +use std::marker::PhantomData; +use std::path::PathBuf; + +use either::Either; +use sqlx_core::{ConnectOptions, Runtime}; + +use crate::PostgresConnection; + +mod builder; +mod default; +mod getters; +mod parse; + +// TODO: RSA Public Key (to avoid the key exchange for caching_sha2 and sha256 plugins) + +/// Options which can be used to configure how a PostgreSQL connection is opened. +/// +#[allow(clippy::module_name_repetitions)] +pub struct PostgresConnectOptions +where + Rt: Runtime, +{ + runtime: PhantomData, + pub(crate) address: Either<(String, u16), PathBuf>, + username: Option, + password: Option, + database: Option, + timezone: String, + charset: String, +} + +impl Clone for PostgresConnectOptions +where + Rt: Runtime, +{ + fn clone(&self) -> Self { + Self { + runtime: PhantomData, + address: self.address.clone(), + username: self.username.clone(), + password: self.password.clone(), + database: self.database.clone(), + timezone: self.timezone.clone(), + charset: self.charset.clone(), + } + } +} + +impl Debug for PostgresConnectOptions +where + Rt: Runtime, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("PostgresConnectOptions") + .field( + "address", + &self + .address + .as_ref() + .map_left(|(host, port)| format!("{}:{}", host, port)) + .map_right(|socket| socket.display()), + ) + .field("username", &self.username) + .field("password", &self.password) + .field("database", &self.database) + .field("timezone", &self.timezone) + .field("charset", &self.charset) + .finish() + } +} + +impl ConnectOptions for PostgresConnectOptions +where + Rt: Runtime, +{ + type Connection = PostgresConnection; + + #[cfg(feature = "async")] + fn connect(&self) -> futures_util::future::BoxFuture<'_, sqlx_core::Result> + where + Self::Connection: Sized, + Rt: sqlx_core::Async, + { + Box::pin(PostgresConnection::::connect_async(self)) + } +} + +#[cfg(feature = "blocking")] +mod blocking { + use sqlx_core::blocking::{ConnectOptions, Runtime}; + + use super::{PostgresConnectOptions, PostgresConnection}; + + impl ConnectOptions for PostgresConnectOptions { + fn connect(&self) -> sqlx_core::Result + where + Self::Connection: Sized, + { + >::connect(self) + } + } +} diff --git a/sqlx-postgres/src/options/builder.rs b/sqlx-postgres/src/options/builder.rs new file mode 100644 index 00000000..07feb7de --- /dev/null +++ b/sqlx-postgres/src/options/builder.rs @@ -0,0 +1,82 @@ +use std::mem; +use std::path::{Path, PathBuf}; + +use either::Either; +use sqlx_core::Runtime; + +impl super::PostgresConnectOptions +where + Rt: Runtime, +{ + /// Sets the hostname of the database server. + /// + /// If the hostname begins with a slash (`/`), it is interpreted as the absolute path + /// to a Unix domain socket file instead of a hostname of a server. + /// + /// Defaults to `localhost`. + /// + pub fn host(&mut self, host: impl AsRef) -> &mut Self { + let host = host.as_ref(); + + self.address = if host.starts_with('/') { + Either::Right(PathBuf::from(&*host)) + } else { + Either::Left((host.into(), self.get_port())) + }; + + self + } + + /// Sets the path of the Unix domain socket to connect to. + /// + /// Overrides [`host()`](#method.host) and [`port()`](#method.port). + /// + pub fn socket(&mut self, socket: impl AsRef) -> &mut Self { + self.address = Either::Right(socket.as_ref().to_owned()); + self + } + + /// Sets the TCP port number of the database server. + /// + /// Defaults to `3306`. + /// + pub fn port(&mut self, port: u16) -> &mut Self { + self.address = match self.address { + Either::Right(_) => Either::Left(("localhost".to_owned(), port)), + Either::Left((ref mut host, _)) => Either::Left((mem::take(host), port)), + }; + + self + } + + /// Sets the username to be used for authentication. + // FIXME: Specify what happens when you do NOT set this + pub fn username(&mut self, username: impl AsRef) -> &mut Self { + self.username = Some(username.as_ref().to_owned()); + self + } + + /// Sets the password to be used for authentication. + pub fn password(&mut self, password: impl AsRef) -> &mut Self { + self.password = Some(password.as_ref().to_owned()); + self + } + + /// Sets the default database for the connection. + pub fn database(&mut self, database: impl AsRef) -> &mut Self { + self.database = Some(database.as_ref().to_owned()); + self + } + + /// Sets the character set for the connection. + pub fn charset(&mut self, charset: impl AsRef) -> &mut Self { + self.charset = charset.as_ref().to_owned(); + self + } + + /// Sets the timezone for the connection. + pub fn timezone(&mut self, timezone: impl AsRef) -> &mut Self { + self.timezone = timezone.as_ref().to_owned(); + self + } +} diff --git a/sqlx-postgres/src/options/default.rs b/sqlx-postgres/src/options/default.rs new file mode 100644 index 00000000..053cc20c --- /dev/null +++ b/sqlx-postgres/src/options/default.rs @@ -0,0 +1,38 @@ +use std::marker::PhantomData; + +use either::Either; +use sqlx_core::Runtime; + +use crate::PostgresConnectOptions; + +pub(crate) const HOST: &str = "localhost"; +pub(crate) const PORT: u16 = 3306; + +impl Default for PostgresConnectOptions +where + Rt: Runtime, +{ + fn default() -> Self { + Self { + runtime: PhantomData, + address: Either::Left((HOST.to_owned(), PORT)), + username: None, + password: None, + database: None, + charset: "utf8mb4".to_owned(), + timezone: "utc".to_owned(), + // todo: connect_timeout + } + } +} + +impl super::PostgresConnectOptions +where + Rt: Runtime, +{ + /// Creates a default set of options ready for configuration. + #[must_use] + pub fn new() -> Self { + Self::default() + } +} diff --git a/sqlx-postgres/src/options/getters.rs b/sqlx-postgres/src/options/getters.rs new file mode 100644 index 00000000..b3e7f6f9 --- /dev/null +++ b/sqlx-postgres/src/options/getters.rs @@ -0,0 +1,55 @@ +use std::path::{Path, PathBuf}; + +use sqlx_core::Runtime; + +use super::{default, PostgresConnectOptions}; + +impl PostgresConnectOptions { + /// Returns the hostname of the database server. + #[must_use] + pub fn get_host(&self) -> &str { + self.address.as_ref().left().map_or(default::HOST, |(host, _)| &**host) + } + + /// Returns the TCP port number of the database server. + #[must_use] + pub fn get_port(&self) -> u16 { + self.address.as_ref().left().map_or(default::PORT, |(_, port)| *port) + } + + /// Returns the path to the Unix domain socket, if one is configured. + #[must_use] + pub fn get_socket(&self) -> Option<&Path> { + self.address.as_ref().right().map(PathBuf::as_path) + } + + /// Returns the default database name. + #[must_use] + pub fn get_database(&self) -> Option<&str> { + self.database.as_deref() + } + + /// Returns the username to be used for authentication. + #[must_use] + pub fn get_username(&self) -> Option<&str> { + self.username.as_deref() + } + + /// Returns the password to be used for authentication. + #[must_use] + pub fn get_password(&self) -> Option<&str> { + self.password.as_deref() + } + + /// Returns the character set for the connection. + #[must_use] + pub fn get_charset(&self) -> &str { + &self.charset + } + + /// Returns the timezone for the connection. + #[must_use] + pub fn get_timezone(&self) -> &str { + &self.timezone + } +} diff --git a/sqlx-postgres/src/options/parse.rs b/sqlx-postgres/src/options/parse.rs new file mode 100644 index 00000000..2c395e82 --- /dev/null +++ b/sqlx-postgres/src/options/parse.rs @@ -0,0 +1,183 @@ +use std::borrow::Cow; +use std::str::FromStr; + +use percent_encoding::percent_decode_str; +use sqlx_core::{Error, Runtime}; +use url::Url; + +use crate::PostgresConnectOptions; + +impl FromStr for PostgresConnectOptions +where + Rt: Runtime, +{ + type Err = Error; + + fn from_str(s: &str) -> Result { + let url: Url = + s.parse().map_err(|error| Error::configuration("for database URL", error))?; + + if !matches!(url.scheme(), "postgres") { + return Err(Error::configuration_msg(format!( + "unsupported URL scheme {:?} for MySQL", + url.scheme() + ))); + } + + let mut options = Self::new(); + + if let Some(host) = url.host_str() { + options.host(percent_decode_str_utf8(host)); + } + + if let Some(port) = url.port() { + options.port(port); + } + + let username = url.username(); + if !username.is_empty() { + options.username(percent_decode_str_utf8(username)); + } + + if let Some(password) = url.password() { + options.password(percent_decode_str_utf8(password)); + } + + let mut path = url.path(); + + if path.starts_with('/') { + path = &path[1..]; + } + + if !path.is_empty() { + options.database(path); + } + + for (key, value) in url.query_pairs() { + match &*key { + "user" | "username" => { + options.username(value); + } + + "password" => { + options.password(value); + } + + // ssl-mode compatibly with SQLx <= 0.5 + // sslmode compatibly with PostgreSQL + // sslMode compatibly with JDBC MySQL + // tls compatibly with Go MySQL [preferred] + "ssl-mode" | "sslmode" | "sslMode" | "tls" => { + todo!() + } + + "charset" => { + options.charset(value); + } + + "timezone" => { + options.timezone(value); + } + + "socket" => { + options.socket(&*value); + } + + _ => { + // ignore unknown connection parameters + // fixme: should we error or warn here? + } + } + } + + Ok(options) + } +} + +// todo: this should probably go somewhere common +fn percent_decode_str_utf8(value: &str) -> Cow<'_, str> { + percent_decode_str(value).decode_utf8_lossy() +} + +#[cfg(test)] +mod tests { + use std::path::Path; + + use sqlx_core::mock::Mock; + + use super::PostgresConnectOptions; + + #[test] + fn parse() { + let url = "postgres://user:password@hostname:5432/database?timezone=system&charset=utf8"; + let options: PostgresConnectOptions = url.parse().unwrap(); + + assert_eq!(options.get_username(), Some("user")); + assert_eq!(options.get_password(), Some("password")); + assert_eq!(options.get_host(), "hostname"); + assert_eq!(options.get_port(), 5432); + assert_eq!(options.get_database(), Some("database")); + assert_eq!(options.get_timezone(), "system"); + assert_eq!(options.get_charset(), "utf8"); + } + + #[test] + fn parse_with_defaults() { + let url = "postgres://"; + let options: PostgresConnectOptions = url.parse().unwrap(); + + assert_eq!(options.get_username(), None); + assert_eq!(options.get_password(), None); + assert_eq!(options.get_host(), "localhost"); + assert_eq!(options.get_port(), 3306); + assert_eq!(options.get_database(), None); + assert_eq!(options.get_timezone(), "utc"); + assert_eq!(options.get_charset(), "utf8mb4"); + } + + #[test] + fn parse_socket_from_query() { + let url = "postgres://user:password@localhost/database?socket=/var/run/postgresd/postgresd.sock"; + let options: PostgresConnectOptions = url.parse().unwrap(); + + assert_eq!(options.get_username(), Some("user")); + assert_eq!(options.get_password(), Some("password")); + assert_eq!(options.get_database(), Some("database")); + assert_eq!(options.get_socket(), Some(Path::new("/var/run/postgresd/postgresd.sock"))); + } + + #[test] + fn parse_socket_from_host() { + // socket path in host requires URL encoding - but does work + let url = "postgres://user:password@%2Fvar%2Frun%2Fpostgresd%2Fpostgresd.sock/database"; + let options: PostgresConnectOptions = url.parse().unwrap(); + + assert_eq!(options.get_username(), Some("user")); + assert_eq!(options.get_password(), Some("password")); + assert_eq!(options.get_database(), Some("database")); + assert_eq!(options.get_socket(), Some(Path::new("/var/run/postgresd/postgresd.sock"))); + } + + #[test] + #[should_panic] + fn fail_to_parse_non_postgres() { + let url = "postgres://user:password@hostname:5432/database?timezone=system&charset=utf8"; + let _: PostgresConnectOptions = url.parse().unwrap(); + } + + #[test] + fn parse_username_with_at_sign() { + let url = "postgres://user@hostname:password@hostname:5432/database"; + let options: PostgresConnectOptions = url.parse().unwrap(); + + assert_eq!(options.get_username(), Some("user@hostname")); + } + + #[test] + fn parse_password_with_non_ascii_chars() { + let url = "postgres://username:p@ssw0rd@hostname:5432/database"; + let options: PostgresConnectOptions = url.parse().unwrap(); + + assert_eq!(options.get_password(), Some("p@ssw0rd")); + } +} diff --git a/sqlx-postgres/src/protocol.rs b/sqlx-postgres/src/protocol.rs new file mode 100644 index 00000000..fe00f0b0 --- /dev/null +++ b/sqlx-postgres/src/protocol.rs @@ -0,0 +1,91 @@ +use std::convert::TryFrom; + +use bytes::Bytes; +use sqlx_core::io::Deserialize; +use sqlx_core::Error; +use sqlx_core::Result; + +mod close; +mod notification; +mod response; +mod startup; +mod terminate; + +pub(crate) use close::Close; +pub(crate) use notification::Notification; +pub(crate) use response::{Notice, PgSeverity}; +pub(crate) use startup::Startup; +pub(crate) use terminate::Terminate; + +#[derive(Debug, Clone, Copy, PartialOrd, PartialEq, Eq)] +#[repr(u8)] +pub enum MessageType { + ParseComplete = b'1', + BindComplete = b'2', + CloseComplete = b'3', + CommandComplete = b'C', + DataRow = b'D', + ErrorResponse = b'E', + EmptyQueryResponse = b'I', + NotificationResponse = b'A', + KeyData = b'K', + NoticeResponse = b'N', + Authentication = b'R', + ParameterStatus = b'S', + RowDescription = b'T', + ReadyForQuery = b'Z', + NoData = b'n', + PortalSuspended = b's', + ParameterDescription = b't', +} + +#[derive(Debug)] +pub struct Message { + pub r#type: MessageType, + pub contents: Bytes, +} + +impl Message { + #[inline] + pub fn decode<'de, T>(self) -> Result + where + T: Deserialize<'de, ()>, + { + T::deserialize_with(self.contents, ()) + } +} + +impl TryFrom for MessageType { + type Error = Error; + + fn try_from(v: u8) -> Result { + // https://www.postgresql.org/docs/current/protocol-message-formats.html + + Ok(match v { + b'1' => MessageType::ParseComplete, + b'2' => MessageType::BindComplete, + b'3' => MessageType::CloseComplete, + b'C' => MessageType::CommandComplete, + b'D' => MessageType::DataRow, + b'E' => MessageType::ErrorResponse, + b'I' => MessageType::EmptyQueryResponse, + b'A' => MessageType::NotificationResponse, + b'K' => MessageType::KeyData, + b'N' => MessageType::NoticeResponse, + b'R' => MessageType::Authentication, + b'S' => MessageType::ParameterStatus, + b'T' => MessageType::RowDescription, + b'Z' => MessageType::ReadyForQuery, + b'n' => MessageType::NoData, + b's' => MessageType::PortalSuspended, + b't' => MessageType::ParameterDescription, + + _ => { + return Err(Error::configuration_msg(format!( + "unknown message type: {:?}", + v as char + ))); + } + }) + } +} diff --git a/sqlx-postgres/src/protocol/close.rs b/sqlx-postgres/src/protocol/close.rs new file mode 100644 index 00000000..95e916b9 --- /dev/null +++ b/sqlx-postgres/src/protocol/close.rs @@ -0,0 +1,36 @@ +use sqlx_core::io::Serialize; +use sqlx_core::Result; + +use crate::io::PgBufMutExt; + +const CLOSE_PORTAL: u8 = b'P'; +const CLOSE_STATEMENT: u8 = b'S'; + +#[derive(Debug)] +#[allow(dead_code)] +pub enum Close { + Statement(u32), + Portal(u32), +} + +impl Serialize<'_, ()> for Close { + fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { + // 15 bytes for 1-digit statement/portal IDs + buf.reserve(20); + buf.push(b'C'); + + buf.write_length_prefixed(|buf| match self { + Close::Statement(id) => { + buf.push(CLOSE_STATEMENT); + buf.write_statement_name(*id); + } + + Close::Portal(id) => { + buf.push(CLOSE_PORTAL); + buf.write_portal_name(Some(*id)); + } + }); + + Ok(()) + } +} diff --git a/sqlx-postgres/src/protocol/flush.rs b/sqlx-postgres/src/protocol/flush.rs new file mode 100644 index 00000000..6967a1e5 --- /dev/null +++ b/sqlx-postgres/src/protocol/flush.rs @@ -0,0 +1,18 @@ +use sqlx_core::io::Serialize; +use sqlx_core::Result; + +// The Flush message does not cause any specific output to be generated, +// but forces the backend to deliver any data pending in its output buffers. + +// A Flush must be sent after any extended-query command except Sync, if the +// frontend wishes to examine the results of that command before issuing more commands. + +#[derive(Debug)] +pub struct Flush; + +impl Serialize<'_, ()> for Flush { + fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { + buf.push(b'H'); + buf.extend(&4_i32.to_be_bytes()); + } +} diff --git a/sqlx-postgres/src/protocol/notification.rs b/sqlx-postgres/src/protocol/notification.rs new file mode 100644 index 00000000..db200fee --- /dev/null +++ b/sqlx-postgres/src/protocol/notification.rs @@ -0,0 +1,28 @@ +use bytes::{Buf, Bytes}; +use bytestring::ByteString; +use sqlx_core::io::BufExt; +use sqlx_core::io::Deserialize; +use sqlx_core::Result; + +#[derive(Debug)] +pub struct Notification { + pub(crate) process_id: u32, + pub(crate) channel: ByteString, + pub(crate) payload: ByteString, +} + +impl Deserialize<'_, ()> for Notification { + fn deserialize_with(mut buf: Bytes, _: ()) -> Result { + let process_id = buf.get_u32(); + + // UNSAFE: This message will not be read. + #[allow(unsafe_code)] + let channel = unsafe { buf.get_str_nul_unchecked()? }; + + // UNSAFE: This message will not be read. + #[allow(unsafe_code)] + let payload = unsafe { buf.get_str_nul_unchecked()? }; + + Ok(Self { process_id, channel, payload }) + } +} diff --git a/sqlx-postgres/src/protocol/response.rs b/sqlx-postgres/src/protocol/response.rs new file mode 100644 index 00000000..93edf7f3 --- /dev/null +++ b/sqlx-postgres/src/protocol/response.rs @@ -0,0 +1,190 @@ +use std::str::from_utf8; + +use bytes::Bytes; +use memchr::memchr; +use sqlx_core::io::Deserialize; +use sqlx_core::Error; +use sqlx_core::Result; + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[repr(u8)] +pub enum PgSeverity { + Panic, + Fatal, + Error, + Warning, + Notice, + Debug, + Info, + Log, +} + +impl PgSeverity { + #[inline] + pub fn is_error(self) -> bool { + matches!(self, Self::Panic | Self::Fatal | Self::Error) + } +} + +impl std::convert::TryFrom<&str> for PgSeverity { + type Error = Error; + + fn try_from(s: &str) -> Result { + let result = match s { + "PANIC" => PgSeverity::Panic, + "FATAL" => PgSeverity::Fatal, + "ERROR" => PgSeverity::Error, + "WARNING" => PgSeverity::Warning, + "NOTICE" => PgSeverity::Notice, + "DEBUG" => PgSeverity::Debug, + "INFO" => PgSeverity::Info, + "LOG" => PgSeverity::Log, + + severity => { + return Err(Error::configuration_msg(format!("unknown severity: {:?}", severity))); + } + }; + + Ok(result) + } +} + +#[derive(Debug)] +pub struct Notice { + storage: Bytes, + severity: PgSeverity, + message: (u16, u16), + code: (u16, u16), +} + +impl Notice { + #[inline] + pub fn severity(&self) -> PgSeverity { + self.severity + } + + #[inline] + pub fn code(&self) -> &str { + self.get_cached_str(self.code) + } + + #[inline] + pub fn message(&self) -> &str { + self.get_cached_str(self.message) + } + + // Field descriptions available here: + // https://www.postgresql.org/docs/current/protocol-error-fields.html + + #[inline] + pub fn get(&self, ty: u8) -> Option<&str> { + self.get_raw(ty).and_then(|v| from_utf8(v).ok()) + } + + pub fn get_raw(&self, ty: u8) -> Option<&[u8]> { + self.fields() + .filter(|(field, _)| *field == ty) + .map(|(_, (start, end))| &self.storage[start as usize..end as usize]) + .next() + } +} + +impl Notice { + #[inline] + fn fields(&self) -> Fields<'_> { + Fields { storage: &self.storage, offset: 0 } + } + + #[inline] + fn get_cached_str(&self, cache: (u16, u16)) -> &str { + // unwrap: this cannot fail at this stage + from_utf8(&self.storage[cache.0 as usize..cache.1 as usize]).unwrap() + } +} + +impl Deserialize<'_, ()> for Notice { + fn deserialize_with(mut buf: Bytes, _: ()) -> Result { + // In order to support PostgreSQL 9.5 and older we need to parse the localized S field. + // Newer versions additionally come with the V field that is guaranteed to be in English. + // We thus read both versions and prefer the unlocalized one if available. + const DEFAULT_SEVERITY: PgSeverity = PgSeverity::Log; + let mut severity_v = None; + let mut severity_s = None; + let mut message = (0, 0); + let mut code = (0, 0); + + // we cache the three always present fields + // this enables to keep the access time down for the fields most likely accessed + + let fields = Fields { storage: &buf, offset: 0 }; + + for (field, v) in fields { + if message.0 != 0 && code.0 != 0 { + // stop iterating when we have the 3 fields we were looking for + // we assume V (severity) was the first field as it should be + break; + } + + use std::convert::TryInto; + match field { + b'S' => { + // Discard potential errors, because the message might be localized + severity_s = + from_utf8(&buf[v.0 as usize..v.1 as usize]).unwrap().try_into().ok(); + } + + b'V' => { + // Propagate errors here, because V is not localized and thus we are missing a possible + // variant. + severity_v = + Some(from_utf8(&buf[v.0 as usize..v.1 as usize]).unwrap().try_into()?); + } + + b'M' => { + message = v; + } + + b'C' => { + code = v; + } + + _ => {} + } + } + + Ok(Self { + severity: severity_v.or(severity_s).unwrap_or(DEFAULT_SEVERITY), + message, + code, + storage: buf, + }) + } +} + +/// An iterator over each field in the Error (or Notice) response. +struct Fields<'a> { + storage: &'a [u8], + offset: u16, +} + +impl<'a> Iterator for Fields<'a> { + type Item = (u8, (u16, u16)); + + fn next(&mut self) -> Option { + // The fields in the response body are sequentially stored as [tag][string], + // ending in a final, additional [nul] + + let ty = self.storage[self.offset as usize]; + + if ty == 0 { + return None; + } + + let nul = memchr(b'\0', &self.storage[(self.offset + 1) as usize..])? as u16; + let offset = self.offset; + + self.offset += nul + 2; + + Some((ty, (offset + 1, offset + nul + 1))) + } +} diff --git a/sqlx-postgres/src/protocol/startup.rs b/sqlx-postgres/src/protocol/startup.rs new file mode 100644 index 00000000..16433e71 --- /dev/null +++ b/sqlx-postgres/src/protocol/startup.rs @@ -0,0 +1,64 @@ +use sqlx_core::io::Serialize; +use sqlx_core::io::WriteExt; +use sqlx_core::Result; + +use crate::io::PgBufMutExt; + +// To begin a session, a frontend opens a connection to the server and sends a startup message. +// This message includes the names of the user and of the database the user wants to connect to; +// it also identifies the particular protocol version to be used. + +// Optionally, the startup message can include additional settings for run-time parameters. + +#[derive(Debug)] +pub struct Startup<'a> { + /// The database user name to connect as. Required; there is no default. + pub username: Option<&'a str>, + + /// The database to connect to. Defaults to the user name. + pub database: Option<&'a str>, + + /// Additional start-up params. + /// + pub params: &'a [(&'a str, &'a str)], +} + +impl Serialize<'_, ()> for Startup<'_> { + fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { + buf.reserve(120); + + buf.write_length_prefixed(|buf| { + // The protocol version number. The most significant 16 bits are the + // major version number (3 for the protocol described here). The least + // significant 16 bits are the minor version number (0 + // for the protocol described here) + buf.extend(&196_608_i32.to_be_bytes()); + + if let Some(username) = self.username { + // The database user name to connect as. + encode_startup_param(buf, "user", username); + } + + if let Some(database) = self.database { + // The database to connect to. Defaults to the user name. + encode_startup_param(buf, "database", database); + } + + for (name, value) in self.params { + encode_startup_param(buf, name, value); + } + + // A zero byte is required as a terminator + // after the last name/value pair. + buf.push(0); + }); + + Ok(()) + } +} + +#[inline] +fn encode_startup_param(buf: &mut Vec, name: &str, value: &str) { + buf.write_str_nul(name); + buf.write_str_nul(value); +} diff --git a/sqlx-postgres/src/protocol/terminate.rs b/sqlx-postgres/src/protocol/terminate.rs new file mode 100644 index 00000000..8b5c29e3 --- /dev/null +++ b/sqlx-postgres/src/protocol/terminate.rs @@ -0,0 +1,14 @@ +use sqlx_core::io::Serialize; +use sqlx_core::Result; + +#[derive(Debug)] +pub struct Terminate; + +impl Serialize<'_, ()> for Terminate { + fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { + buf.push(b'X'); + buf.extend(&4_u32.to_be_bytes()); + + Ok(()) + } +} diff --git a/sqlx/Cargo.toml b/sqlx/Cargo.toml index e4d67ece..ef2da2ce 100644 --- a/sqlx/Cargo.toml +++ b/sqlx/Cargo.toml @@ -33,7 +33,13 @@ mysql = ["sqlx-mysql"] mysql-async = ["async", "mysql", "sqlx-mysql/async"] mysql-blocking = ["blocking", "mysql", "sqlx-mysql/blocking"] +# Postgres +postgres = ["sqlx-postgres"] +postgres-async = ["async", "postgres", "sqlx-postgres/async"] +postgres-blocking = ["blocking", "postgres", "sqlx-postgres/blocking"] + [dependencies] sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core" } sqlx-mysql = { version = "0.6.0-pre", path = "../sqlx-mysql", optional = true } +sqlx-postgres = { version = "0.6.0-pre", path = "../sqlx-postgres", optional = true } futures-util = { version = "0.3", optional = true, features = ["io"] } diff --git a/sqlx/src/lib.rs b/sqlx/src/lib.rs index ce58a159..b6ff782e 100644 --- a/sqlx/src/lib.rs +++ b/sqlx/src/lib.rs @@ -51,6 +51,9 @@ mod runtime; #[cfg(feature = "mysql")] pub mod mysql; +#[cfg(feature = "postgres")] +pub mod postgres; + #[cfg(feature = "blocking")] pub use blocking::Blocking; pub use runtime::DefaultRuntime; diff --git a/sqlx/src/postgres/blocking.rs b/sqlx/src/postgres/blocking.rs new file mode 100644 index 00000000..217280ea --- /dev/null +++ b/sqlx/src/postgres/blocking.rs @@ -0,0 +1,2 @@ +mod connection; +mod options; diff --git a/sqlx/src/postgres/blocking/connection.rs b/sqlx/src/postgres/blocking/connection.rs new file mode 100644 index 00000000..3d95032b --- /dev/null +++ b/sqlx/src/postgres/blocking/connection.rs @@ -0,0 +1,59 @@ +use crate::blocking::{Close, Connect, Connection, Runtime}; +use crate::postgres::connection::PostgresConnection; +use crate::{Blocking, Result}; + +impl PostgresConnection { + /// Open a new database connection. + /// + /// For detailed information, refer to the async version of + /// this: [`connect`](#method.connect). + /// + /// Implemented with [`Connect::connect`]. + #[inline] + pub fn connect(url: &str) -> Result { + sqlx_postgres::PostgresConnection::::connect(url).map(Self) + } + + /// Checks if a connection to the database is still valid. + /// + /// For detailed information, refer to the async version of + /// this: [`ping`](#method.ping). + /// + /// Implemented with [`Connection::ping`]. + #[inline] + pub fn ping(&mut self) -> Result<()> { + self.0.ping() + } + + /// Explicitly close this database connection. + /// + /// For detailed information, refer to the async version of + /// this: [`close`](#method.close). + /// + /// Implemented with [`Close::close`]. + #[inline] + pub fn close(self) -> Result<()> { + self.0.close() + } +} + +impl Close for PostgresConnection { + #[inline] + fn close(self) -> Result<()> { + self.0.close() + } +} + +impl Connect for PostgresConnection { + #[inline] + fn connect(url: &str) -> Result { + sqlx_postgres::PostgresConnection::::connect(url).map(Self) + } +} + +impl Connection for PostgresConnection { + #[inline] + fn ping(&mut self) -> Result<()> { + self.0.ping() + } +} diff --git a/sqlx/src/postgres/blocking/options.rs b/sqlx/src/postgres/blocking/options.rs new file mode 100644 index 00000000..42ebfc6c --- /dev/null +++ b/sqlx/src/postgres/blocking/options.rs @@ -0,0 +1,28 @@ +use crate::blocking::{ConnectOptions, Runtime}; +use crate::postgres::{PostgresConnectOptions, PostgresConnection}; +use crate::{Blocking, Result}; + +impl PostgresConnectOptions { + /// Open a new database connection with the configured connection options. + /// + /// For detailed information, refer to the async version of + /// this: [`connect`](#method.connect). + /// + /// Implemented with [`ConnectOptions::connect`]. + #[inline] + pub fn connect(&self) -> Result> { + as ConnectOptions>::connect(&self.0) + .map(PostgresConnection::) + } +} + +impl ConnectOptions for PostgresConnectOptions { + #[inline] + fn connect(&self) -> Result + where + Self::Connection: Sized, + { + as ConnectOptions>::connect(&self.0) + .map(PostgresConnection::) + } +} diff --git a/sqlx/src/postgres/connection.rs b/sqlx/src/postgres/connection.rs new file mode 100644 index 00000000..c3b67b27 --- /dev/null +++ b/sqlx/src/postgres/connection.rs @@ -0,0 +1,93 @@ +use std::fmt::{self, Debug, Formatter}; + +#[cfg(feature = "async")] +use futures_util::future::{BoxFuture, FutureExt}; + +use super::{Postgres, PostgresConnectOptions}; +#[cfg(feature = "async")] +use crate::{Async, Result}; +use crate::{Close, Connect, Connection, DefaultRuntime, Runtime}; + +/// A single connection (also known as a session) to a MySQL database server. +#[allow(clippy::module_name_repetitions)] +pub struct PostgresConnection( + pub(super) sqlx_postgres::PostgresConnection, +); + +#[cfg(feature = "async")] +impl PostgresConnection { + /// Open a new database connection. + /// + /// A value of [`PostgresConnectOptions`] is parsed from the provided + /// connection `url`. + /// + /// ```text + /// postgres://[[user[:password]@]host][/database][?properties] + /// ``` + /// + /// Implemented with [`Connect::connect`][crate::Connect::connect]. + pub async fn connect(url: &str) -> Result { + sqlx_postgres::PostgresConnection::::connect(url).await.map(Self) + } + + /// Checks if a connection to the database is still valid. + /// + /// Implemented with [`Connection::ping`][crate::Connection::ping]. + pub async fn ping(&mut self) -> Result<()> { + self.0.ping().await + } + + /// Explicitly close this database connection. + /// + /// This method is **not required** for safe and consistent operation. However, it is + /// recommended to call it instead of letting a connection `drop` as MySQL + /// will be faster at cleaning up resources. + /// + /// Implemented with [`Close::close`][crate::Close::close]. + pub async fn close(self) -> Result<()> { + self.0.close().await + } +} + +impl Debug for PostgresConnection { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self.0) + } +} + +impl Close for PostgresConnection { + #[cfg(feature = "async")] + #[inline] + fn close(self) -> BoxFuture<'static, Result<()>> + where + Rt: Async, + { + self.close().boxed() + } +} + +impl Connect for PostgresConnection { + type Options = PostgresConnectOptions; + + #[cfg(feature = "async")] + #[inline] + fn connect(url: &str) -> BoxFuture<'_, Result> + where + Rt: Async, + { + Self::connect(url).boxed() + } +} + +impl Connection for PostgresConnection { + type Database = Postgres; + + #[cfg(feature = "async")] + #[inline] + fn ping(&mut self) -> BoxFuture<'_, Result<()>> + where + Rt: Async, + { + self.ping().boxed() + } +} diff --git a/sqlx/src/postgres/database.rs b/sqlx/src/postgres/database.rs new file mode 100644 index 00000000..c2a974d8 --- /dev/null +++ b/sqlx/src/postgres/database.rs @@ -0,0 +1,15 @@ +use sqlx_core::HasOutput; + +use super::PostgresConnection; +use crate::{Database, Runtime}; + +#[derive(Debug)] +pub struct Postgres; + +impl Database for Postgres { + type Connection = PostgresConnection; +} + +impl<'x> HasOutput<'x> for Postgres { + type Output = &'x mut Vec; +} diff --git a/sqlx/src/postgres/options.rs b/sqlx/src/postgres/options.rs new file mode 100644 index 00000000..dd4e4a88 --- /dev/null +++ b/sqlx/src/postgres/options.rs @@ -0,0 +1,95 @@ +use std::fmt::{self, Debug, Formatter}; +use std::str::FromStr; + +#[cfg(feature = "async")] +use futures_util::future::{BoxFuture, FutureExt}; + +use crate::postgres::PostgresConnection; +#[cfg(feature = "async")] +use crate::Async; +use crate::{ConnectOptions, DefaultRuntime, Error, Result, Runtime}; + +mod builder; +mod getters; + +/// Options which can be used to configure how a MySQL connection is opened. +#[allow(clippy::module_name_repetitions)] +pub struct PostgresConnectOptions( + pub(super) sqlx_postgres::PostgresConnectOptions, +); + +impl PostgresConnectOptions { + /// Creates a default set of connection options. + /// + /// Implemented with [`Default`](#impl-Default). + #[inline] + pub fn new() -> Self { + Self::default() + } + + /// Parses connection options from a connection URL. + /// + /// ```text + /// postgres://[[user[:password]@]host][/database][?properties] + /// ``` + /// + /// Implemented with [`FromStr`](#impl-FromStr). + /// + #[inline] + pub fn parse(url: &str) -> Result { + Ok(Self(url.parse()?)) + } +} + +#[cfg(feature = "async")] +impl PostgresConnectOptions { + /// Open a new database connection with the configured connection options. + /// + /// Implemented with [`ConnectOptions::connect`]. + #[inline] + pub async fn connect(&self) -> Result> { + as ConnectOptions>::connect(&self.0) + .await + .map(PostgresConnection) + } +} + +impl ConnectOptions for PostgresConnectOptions { + type Connection = PostgresConnection; + + #[cfg(feature = "async")] + #[inline] + fn connect(&self) -> BoxFuture<'_, Result> + where + Self::Connection: Sized, + Rt: Async, + { + self.connect().boxed() + } +} + +impl Debug for PostgresConnectOptions { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self.0) + } +} + +impl Default for PostgresConnectOptions { + fn default() -> Self { + Self(sqlx_postgres::PostgresConnectOptions::::default()) + } +} + +impl Clone for PostgresConnectOptions { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl FromStr for PostgresConnectOptions { + type Err = Error; + + fn from_str(url: &str) -> Result { + Ok(Self(url.parse()?)) + } +} diff --git a/sqlx/src/postgres/options/builder.rs b/sqlx/src/postgres/options/builder.rs new file mode 100644 index 00000000..d1630601 --- /dev/null +++ b/sqlx/src/postgres/options/builder.rs @@ -0,0 +1,74 @@ +use std::path::Path; + +use super::PostgresConnectOptions; +use crate::Runtime; + +impl PostgresConnectOptions { + /// Sets the hostname of the database server. + /// + /// If the hostname begins with a slash (`/`), it is interpreted as the absolute path + /// to a Unix domain socket file instead of a hostname of a server. + /// + /// Defaults to `localhost`. + /// + #[inline] + pub fn host(&mut self, host: impl AsRef) -> &mut Self { + self.0.host(host); + self + } + + /// Sets the path of the Unix domain socket to connect to. + /// + /// Overrides [`host()`](#method.host) and [`port()`](#method.port). + /// + #[inline] + pub fn socket(&mut self, socket: impl AsRef) -> &mut Self { + self.0.socket(socket); + self + } + + /// Sets the TCP port number of the database server. + /// + /// Defaults to `3306`. + /// + #[inline] + pub fn port(&mut self, port: u16) -> &mut Self { + self.0.port(port); + self + } + + /// Sets the username to be used for authentication. + // FIXME: Specify what happens when you do NOT set this + pub fn username(&mut self, username: impl AsRef) -> &mut Self { + self.0.username(username); + self + } + + /// Sets the password to be used for authentication. + #[inline] + pub fn password(&mut self, password: impl AsRef) -> &mut Self { + self.0.password(password); + self + } + + /// Sets the default database for the connection. + #[inline] + pub fn database(&mut self, database: impl AsRef) -> &mut Self { + self.0.database(database); + self + } + + /// Sets the character set for the connection. + #[inline] + pub fn charset(&mut self, charset: impl AsRef) -> &mut Self { + self.0.charset(charset); + self + } + + /// Sets the timezone for the connection. + #[inline] + pub fn timezone(&mut self, timezone: impl AsRef) -> &mut Self { + self.0.timezone(timezone); + self + } +} diff --git a/sqlx/src/postgres/options/getters.rs b/sqlx/src/postgres/options/getters.rs new file mode 100644 index 00000000..ea63beb3 --- /dev/null +++ b/sqlx/src/postgres/options/getters.rs @@ -0,0 +1,62 @@ +use std::path::Path; + +use super::PostgresConnectOptions; +use crate::Runtime; + +impl PostgresConnectOptions { + /// Returns the hostname of the database server. + #[must_use] + #[inline] + pub fn get_host(&self) -> &str { + self.0.get_host() + } + + /// Returns the TCP port number of the database server. + #[must_use] + #[inline] + pub fn get_port(&self) -> u16 { + self.0.get_port() + } + + /// Returns the path to the Unix domain socket, if one is configured. + #[must_use] + #[inline] + pub fn get_socket(&self) -> Option<&Path> { + self.0.get_socket() + } + + /// Returns the default database name. + #[must_use] + #[inline] + pub fn get_database(&self) -> Option<&str> { + self.0.get_database() + } + + /// Returns the username to be used for authentication. + #[must_use] + #[inline] + pub fn get_username(&self) -> Option<&str> { + self.0.get_username() + } + + /// Returns the password to be used for authentication. + #[must_use] + #[inline] + pub fn get_password(&self) -> Option<&str> { + self.0.get_password() + } + + /// Returns the character set for the connection. + #[must_use] + #[inline] + pub fn get_charset(&self) -> &str { + self.0.get_charset() + } + + /// Returns the timezone for the connection. + #[must_use] + #[inline] + pub fn get_timezone(&self) -> &str { + self.0.get_timezone() + } +} diff --git a/x.py b/x.py index a1db09c2..f523ebab 100755 --- a/x.py +++ b/x.py @@ -148,16 +148,19 @@ def main(): # run checks run_checks("sqlx-core") run_checks("sqlx-mysql") + run_checks("sqlx-postgres") run_checks("sqlx") # run checks run_checks("sqlx-core", cmd="clippy") run_checks("sqlx-mysql", cmd="clippy") + run_checks("sqlx-postgres", cmd="clippy") run_checks("sqlx", cmd="clippy") # run docs (only if asked) run_docs("sqlx-core") run_docs("sqlx-mysql") + run_docs("sqlx-postgres") run_docs("sqlx") # run unit tests, collect test binary filenames