diff --git a/Cargo.lock b/Cargo.lock index d15a3b497..b940562f1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2152,6 +2152,23 @@ dependencies = [ "url", ] +[[package]] +name = "sqlx-postgres" +version = "0.1.0-pre" +dependencies = [ + "atoi", + "base64", + "bytes", + "futures-core", + "futures-util", + "md-5", + "memchr", + "sqlx-core2", + "sqlx-rt", + "url", + "whoami", +] + [[package]] name = "sqlx-rt" version = "0.1.1" diff --git a/Cargo.toml b/Cargo.toml index 928206924..57b1bc155 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ members = [ ".", "sqlx-core", "sqlx-core2", + "sqlx-postgres", "sqlx-rt", "sqlx-macros", "sqlx-test", diff --git a/sqlx-core/src/postgres/connection/establish.rs b/sqlx-core/src/postgres/connection/establish.rs deleted file mode 100644 index 86da97c5b..000000000 --- a/sqlx-core/src/postgres/connection/establish.rs +++ /dev/null @@ -1,141 +0,0 @@ -use hashbrown::HashMap; - -use crate::common::StatementCache; -use crate::error::Error; -use crate::io::Decode; -use crate::postgres::connection::{sasl, stream::PgStream, tls}; -use crate::postgres::message::{ - Authentication, BackendKeyData, MessageFormat, Password, ReadyForQuery, Startup, -}; -use crate::postgres::{PgConnectOptions, PgConnection}; - -// https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.3 -// https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.11 - -impl PgConnection { - pub(crate) async fn establish(options: &PgConnectOptions) -> Result { - let mut stream = PgStream::connect(options).await?; - - // Upgrade to TLS if we were asked to and the server supports it - tls::maybe_upgrade(&mut stream, options).await?; - - // To begin a session, a frontend opens a connection to the server - // and sends a startup message. - - stream - .send(Startup { - username: Some(&options.username), - database: options.database.as_deref(), - params: &[ - // Sets the display format for date and time values, - // as well as the rules for interpreting ambiguous date input values. - ("DateStyle", "ISO, MDY"), - // Sets the client-side encoding (character set). - // - ("client_encoding", "UTF8"), - // Sets the time zone for displaying and interpreting time stamps. - ("TimeZone", "UTC"), - // Adjust postgres to return precise values for floats - // NOTE: This is default in postgres 12+ - ("extra_float_digits", "3"), - ], - }) - .await?; - - // The server then uses this information and the contents of - // its configuration files (such as pg_hba.conf) to determine whether the connection is - // provisionally acceptable, and what additional - // authentication is required (if any). - - let mut process_id = 0; - let mut secret_key = 0; - let transaction_status; - - loop { - let message = stream.recv().await?; - match message.format { - MessageFormat::Authentication => match message.decode()? { - Authentication::Ok => { - // the authentication exchange is successfully completed - // do nothing; no more information is required to continue - } - - Authentication::CleartextPassword => { - // The frontend must now send a [PasswordMessage] containing the - // password in clear-text form. - - stream - .send(Password::Cleartext( - options.password.as_deref().unwrap_or_default(), - )) - .await?; - } - - Authentication::Md5Password(body) => { - // The frontend must now send a [PasswordMessage] containing the - // password (with user name) encrypted via MD5, then encrypted again - // using the 4-byte random salt specified in the - // [AuthenticationMD5Password] message. - - stream - .send(Password::Md5 { - username: &options.username, - password: options.password.as_deref().unwrap_or_default(), - salt: body.salt, - }) - .await?; - } - - Authentication::Sasl(body) => { - sasl::authenticate(&mut stream, options, body).await?; - } - - method => { - return Err(err_protocol!( - "unsupported authentication method: {:?}", - method - )); - } - }, - - MessageFormat::BackendKeyData => { - // provides secret-key data that the frontend must save if it wants to be - // able to issue cancel requests later - - let data: BackendKeyData = message.decode()?; - - process_id = data.process_id; - secret_key = data.secret_key; - } - - MessageFormat::ReadyForQuery => { - // start-up is completed. The frontend can now issue commands - transaction_status = - ReadyForQuery::decode(message.contents)?.transaction_status; - - break; - } - - _ => { - return Err(err_protocol!( - "establish: unexpected message: {:?}", - message.format - )) - } - } - } - - Ok(PgConnection { - stream, - process_id, - secret_key, - transaction_status, - transaction_depth: 0, - pending_ready_for_query_count: 0, - next_statement_id: 1, - cache_statement: StatementCache::new(options.statement_cache_capacity), - cache_type_oid: HashMap::new(), - cache_type_info: HashMap::new(), - }) - } -} diff --git a/sqlx-core/src/postgres/database.rs b/sqlx-core/src/postgres/database.rs deleted file mode 100644 index f3041dd4c..000000000 --- a/sqlx-core/src/postgres/database.rs +++ /dev/null @@ -1,49 +0,0 @@ -use crate::database::{Database, HasArguments, HasStatement, HasStatementCache, HasValueRef}; -use crate::postgres::arguments::PgArgumentBuffer; -use crate::postgres::value::{PgValue, PgValueRef}; -use crate::postgres::{ - PgArguments, PgColumn, PgConnection, PgDone, PgRow, PgStatement, PgTransactionManager, - PgTypeInfo, -}; - -/// PostgreSQL database driver. -#[derive(Debug)] -pub struct Postgres; - -impl Database for Postgres { - type Connection = PgConnection; - - type TransactionManager = PgTransactionManager; - - type Row = PgRow; - - type Done = PgDone; - - type Column = PgColumn; - - type TypeInfo = PgTypeInfo; - - type Value = PgValue; -} - -impl<'r> HasValueRef<'r> for Postgres { - type Database = Postgres; - - type ValueRef = PgValueRef<'r>; -} - -impl HasArguments<'_> for Postgres { - type Database = Postgres; - - type Arguments = PgArguments; - - type ArgumentBuffer = PgArgumentBuffer; -} - -impl<'q> HasStatement<'q> for Postgres { - type Database = Postgres; - - type Statement = PgStatement<'q>; -} - -impl HasStatementCache for Postgres {} diff --git a/sqlx-core/src/postgres/io/buf_mut.rs b/sqlx-core/src/postgres/io/buf_mut.rs deleted file mode 100644 index c78633a64..000000000 --- a/sqlx-core/src/postgres/io/buf_mut.rs +++ /dev/null @@ -1,51 +0,0 @@ -pub trait PgBufMutExt { - fn put_length_prefixed(&mut self, f: F) - where - F: FnOnce(&mut Vec); - - fn put_statement_name(&mut self, id: u32); - - fn put_portal_name(&mut self, id: Option); -} - -impl PgBufMutExt for Vec { - // writes a length-prefixed message, this is used when encoding nearly all messages as postgres - // wants us to send the length of the often-variable-sized messages up front - fn put_length_prefixed(&mut self, f: F) - where - F: FnOnce(&mut Vec), - { - // reserve space to write the prefixed length - let offset = self.len(); - self.extend(&[0; 4]); - - // write the main body of the message - f(self); - - // now calculate the size of what we wrote and set the length value - let size = (self.len() - offset) as i32; - self[offset..(offset + 4)].copy_from_slice(&size.to_be_bytes()); - } - - // writes a statement name by ID - #[inline] - fn put_statement_name(&mut self, id: u32) { - self.extend(b"sqlx_s_"); - - itoa::write(&mut *self, id).unwrap(); - - self.push(0); - } - - // writes a portal name by ID - #[inline] - fn put_portal_name(&mut self, id: Option) { - if let Some(id) = id { - self.extend(b"sqlx_p_"); - - itoa::write(&mut *self, id).unwrap(); - } - - self.push(0); - } -} diff --git a/sqlx-core/src/postgres/io/mod.rs b/sqlx-core/src/postgres/io/mod.rs deleted file mode 100644 index 37988b6d4..000000000 --- a/sqlx-core/src/postgres/io/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod buf_mut; - -pub use buf_mut::PgBufMutExt; diff --git a/sqlx-core/src/postgres/message/backend_key_data.rs b/sqlx-core/src/postgres/message/backend_key_data.rs deleted file mode 100644 index 03ae6413e..000000000 --- a/sqlx-core/src/postgres/message/backend_key_data.rs +++ /dev/null @@ -1,48 +0,0 @@ -use byteorder::{BigEndian, ByteOrder}; -use bytes::Bytes; - -use crate::error::Error; -use crate::io::Decode; - -/// Contains cancellation key data. The frontend must save these values if it -/// wishes to be able to issue `CancelRequest` messages later. -#[derive(Debug)] -pub struct BackendKeyData { - /// The process ID of this database. - pub process_id: u32, - - /// The secret key of this database. - pub secret_key: u32, -} - -impl Decode<'_> for BackendKeyData { - fn decode_with(buf: Bytes, _: ()) -> Result { - let process_id = BigEndian::read_u32(&buf); - let secret_key = BigEndian::read_u32(&buf[4..]); - - Ok(Self { - process_id, - secret_key, - }) - } -} - -#[test] -fn test_decode_backend_key_data() { - const DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+"; - - let m = BackendKeyData::decode(DATA.into()).unwrap(); - - assert_eq!(m.process_id, 10182); - assert_eq!(m.secret_key, 2303903019); -} - -#[cfg(all(test, not(debug_assertions)))] -#[bench] -fn bench_decode_backend_key_data(b: &mut test::Bencher) { - const DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+"; - - b.iter(|| { - BackendKeyData::decode(test::black_box(Bytes::from_static(DATA))).unwrap(); - }); -} diff --git a/sqlx-core/src/postgres/message/command_complete.rs b/sqlx-core/src/postgres/message/command_complete.rs deleted file mode 100644 index 87ed55726..000000000 --- a/sqlx-core/src/postgres/message/command_complete.rs +++ /dev/null @@ -1,81 +0,0 @@ -use atoi::atoi; -use bytes::Bytes; -use memchr::memrchr; - -use crate::error::Error; -use crate::io::Decode; - -#[derive(Debug)] -pub struct CommandComplete { - /// The command tag. This is usually a single word that identifies which SQL command - /// was completed. - tag: Bytes, -} - -impl Decode<'_> for CommandComplete { - #[inline] - fn decode_with(buf: Bytes, _: ()) -> Result { - Ok(CommandComplete { tag: buf }) - } -} - -impl CommandComplete { - /// Returns the number of rows affected. - /// If the command does not return rows (e.g., "CREATE TABLE"), returns 0. - pub fn rows_affected(&self) -> u64 { - // Look backwards for the first SPACE - memrchr(b' ', &self.tag) - // This is either a word or the number of rows affected - .and_then(|i| atoi(&self.tag[(i + 1)..])) - .unwrap_or(0) - } -} - -#[test] -fn test_decode_command_complete_for_insert() { - const DATA: &[u8] = b"INSERT 0 1214\0"; - - let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); - - assert_eq!(cc.rows_affected(), 1214); -} - -#[test] -fn test_decode_command_complete_for_begin() { - const DATA: &[u8] = b"BEGIN\0"; - - let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); - - assert_eq!(cc.rows_affected(), 0); -} - -#[test] -fn test_decode_command_complete_for_update() { - const DATA: &[u8] = b"UPDATE 5\0"; - - let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); - - assert_eq!(cc.rows_affected(), 5); -} - -#[cfg(all(test, not(debug_assertions)))] -#[bench] -fn bench_decode_command_complete(b: &mut test::Bencher) { - const DATA: &[u8] = b"INSERT 0 1214\0"; - - b.iter(|| { - let _ = CommandComplete::decode(test::black_box(Bytes::from_static(DATA))); - }); -} - -#[cfg(all(test, not(debug_assertions)))] -#[bench] -fn bench_decode_command_complete_rows_affected(b: &mut test::Bencher) { - const DATA: &[u8] = b"INSERT 0 1214\0"; - - let data = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); - - b.iter(|| { - let _rows = test::black_box(&data).rows_affected(); - }); -} diff --git a/sqlx-core/src/postgres/message/query.rs b/sqlx-core/src/postgres/message/query.rs deleted file mode 100644 index 8f49aabc3..000000000 --- a/sqlx-core/src/postgres/message/query.rs +++ /dev/null @@ -1,27 +0,0 @@ -use crate::io::{BufMutExt, Encode}; - -#[derive(Debug)] -pub struct Query<'a>(pub &'a str); - -impl Encode<'_> for Query<'_> { - fn encode_with(&self, buf: &mut Vec, _: ()) { - let len = 4 + self.0.len() + 1; - - buf.reserve(len + 1); - buf.push(b'Q'); - buf.extend(&(len as i32).to_be_bytes()); - buf.put_str_nul(self.0); - } -} - -#[test] -fn test_encode_query() { - const EXPECTED: &[u8] = b"Q\0\0\0\rSELECT 1\0"; - - let mut buf = Vec::new(); - let m = Query("SELECT 1"); - - m.encode(&mut buf); - - assert_eq!(buf, EXPECTED); -} diff --git a/sqlx-core/src/postgres/message/startup.rs b/sqlx-core/src/postgres/message/startup.rs deleted file mode 100644 index 34f6fa4e9..000000000 --- a/sqlx-core/src/postgres/message/startup.rs +++ /dev/null @@ -1,93 +0,0 @@ -use crate::io::{BufMutExt, Encode}; -use crate::postgres::io::PgBufMutExt; - -// To begin a session, a frontend opens a connection to the server and sends a startup message. -// This message includes the names of the user and of the database the user wants to connect to; -// it also identifies the particular protocol version to be used. - -// Optionally, the startup message can include additional settings for run-time parameters. - -pub struct Startup<'a> { - /// The database user name to connect as. Required; there is no default. - pub username: Option<&'a str>, - - /// The database to connect to. Defaults to the user name. - pub database: Option<&'a str>, - - /// Additional start-up params. - /// - pub params: &'a [(&'a str, &'a str)], -} - -impl Encode<'_> for Startup<'_> { - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.reserve(120); - - buf.put_length_prefixed(|buf| { - // The protocol version number. The most significant 16 bits are the - // major version number (3 for the protocol described here). The least - // significant 16 bits are the minor version number (0 - // for the protocol described here) - buf.extend(&196_608_i32.to_be_bytes()); - - if let Some(username) = self.username { - // The database user name to connect as. - encode_startup_param(buf, "user", username); - } - - if let Some(database) = self.database { - // The database to connect to. Defaults to the user name. - encode_startup_param(buf, "database", database); - } - - for (name, value) in self.params { - encode_startup_param(buf, name, value); - } - - // A zero byte is required as a terminator - // after the last name/value pair. - buf.push(0); - }); - } -} - -#[inline] -fn encode_startup_param(buf: &mut Vec, name: &str, value: &str) { - buf.put_str_nul(name); - buf.put_str_nul(value); -} - -#[test] -fn test_encode_startup() { - const EXPECTED: &[u8] = b"\0\0\0)\0\x03\0\0user\0postgres\0database\0postgres\0\0"; - - let mut buf = Vec::new(); - let m = Startup { - username: Some("postgres"), - database: Some("postgres"), - params: &[], - }; - - m.encode(&mut buf); - - assert_eq!(buf, EXPECTED); -} - -#[cfg(all(test, not(debug_assertions)))] -#[bench] -fn bench_encode_startup(b: &mut test::Bencher) { - use test::black_box; - - let mut buf = Vec::with_capacity(128); - - b.iter(|| { - buf.clear(); - - black_box(Startup { - username: Some("postgres"), - database: Some("postgres"), - params: &[], - }) - .encode(&mut buf); - }); -} diff --git a/sqlx-core/src/postgres/message/sync.rs b/sqlx-core/src/postgres/message/sync.rs deleted file mode 100644 index bc30114ef..000000000 --- a/sqlx-core/src/postgres/message/sync.rs +++ /dev/null @@ -1,11 +0,0 @@ -use crate::io::Encode; - -#[derive(Debug)] -pub struct Sync; - -impl Encode<'_> for Sync { - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.push(b'S'); - buf.extend(&4_i32.to_be_bytes()); - } -} diff --git a/sqlx-core/src/postgres/message/terminate.rs b/sqlx-core/src/postgres/message/terminate.rs deleted file mode 100644 index 98e41fdba..000000000 --- a/sqlx-core/src/postgres/message/terminate.rs +++ /dev/null @@ -1,10 +0,0 @@ -use crate::io::Encode; - -pub struct Terminate; - -impl Encode<'_> for Terminate { - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.push(b'X'); - buf.extend(&4_u32.to_be_bytes()); - } -} diff --git a/sqlx-core/src/postgres/options/connect.rs b/sqlx-core/src/postgres/options/connect.rs deleted file mode 100644 index 6b6a8e3b6..000000000 --- a/sqlx-core/src/postgres/options/connect.rs +++ /dev/null @@ -1,15 +0,0 @@ -use crate::connection::ConnectOptions; -use crate::error::Error; -use crate::postgres::{PgConnectOptions, PgConnection}; -use futures_core::future::BoxFuture; - -impl ConnectOptions for PgConnectOptions { - type Connection = PgConnection; - - fn connect(&self) -> BoxFuture<'_, Result> - where - Self::Connection: Sized, - { - Box::pin(PgConnection::establish(self)) - } -} diff --git a/sqlx-postgres/Cargo.toml b/sqlx-postgres/Cargo.toml new file mode 100644 index 000000000..a7cf57885 --- /dev/null +++ b/sqlx-postgres/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "sqlx-postgres" +version = "0.1.0-pre" +edition = "2018" +authors = ["Ryan Leckey "] + +[dependencies] +sqlx-rt = { path = "../sqlx-rt", version = "0.1.1" } +sqlx-core = { package = "sqlx-core2", path = "../sqlx-core2", version = "0.4.0-beta.2" } +bytes = "0.5.6" +md-5 = "0.9.1" +atoi = "0.3.2" +memchr = "2.3.3" +whoami = "0.9.0" +url = "2.1.1" +futures-core = "0.3.5" +futures-util = "0.3.5" +base64 = "0.12.3" diff --git a/sqlx-core/src/postgres/message/authentication.rs b/sqlx-postgres/src/codec/backend/authentication.rs similarity index 77% rename from sqlx-core/src/postgres/message/authentication.rs rename to sqlx-postgres/src/codec/backend/authentication.rs index 47625acfd..59c2f75d3 100644 --- a/sqlx-core/src/postgres/message/authentication.rs +++ b/sqlx-postgres/src/codec/backend/authentication.rs @@ -1,10 +1,7 @@ -use std::str::from_utf8; - use bytes::{Buf, Bytes}; use memchr::memchr; - -use crate::error::Error; -use crate::io::Decode; +use sqlx_core::{error::Error, io::Decode}; +use std::str::from_utf8; // On startup, the server sends an appropriate authentication request message, // to which the frontend must reply with an appropriate authentication @@ -21,7 +18,7 @@ use crate::io::Decode; // #[derive(Debug)] -pub enum Authentication { +pub(crate) enum Authentication { /// The authentication exchange is successfully completed. Ok, @@ -77,7 +74,10 @@ impl Decode<'_> for Authentication { 12 => Authentication::SaslFinal(AuthenticationSaslFinal::decode(buf)?), ty => { - return Err(err_protocol!("unknown authentication method: {}", ty)); + return Err(Error::protocol_msg(format!( + "unknown authentication method: {}", + ty + ))); } }) } @@ -85,23 +85,23 @@ impl Decode<'_> for Authentication { /// Body of [Authentication::Md5Password]. #[derive(Debug)] -pub struct AuthenticationMd5Password { - pub salt: [u8; 4], +pub(crate) struct AuthenticationMd5Password { + pub(crate) salt: [u8; 4], } /// Body of [Authentication::Sasl]. #[derive(Debug)] -pub struct AuthenticationSasl(Bytes); +pub(crate) struct AuthenticationSasl(Bytes); impl AuthenticationSasl { #[inline] - pub fn mechanisms(&self) -> SaslMechanisms<'_> { + pub(crate) fn mechanisms(&self) -> SaslMechanisms<'_> { SaslMechanisms(&self.0) } } /// An iterator over the SASL authentication mechanisms provided by the server. -pub struct SaslMechanisms<'a>(&'a [u8]); +pub(crate) struct SaslMechanisms<'a>(&'a [u8]); impl<'a> Iterator for SaslMechanisms<'a> { type Item = &'a str; @@ -120,11 +120,11 @@ impl<'a> Iterator for SaslMechanisms<'a> { } #[derive(Debug)] -pub struct AuthenticationSaslContinue { - pub salt: Vec, - pub iterations: u32, - pub nonce: String, - pub message: String, +pub(crate) struct AuthenticationSaslContinue { + pub(crate) salt: Vec, + pub(crate) iterations: u32, + pub(crate) nonce: String, + pub(crate) message: String, } impl Decode<'_> for AuthenticationSaslContinue { @@ -167,8 +167,8 @@ impl Decode<'_> for AuthenticationSaslContinue { } #[derive(Debug)] -pub struct AuthenticationSaslFinal { - pub verifier: Vec, +pub(crate) struct AuthenticationSaslFinal { + pub(crate) verifier: Vec, } impl Decode<'_> for AuthenticationSaslFinal { @@ -187,3 +187,40 @@ impl Decode<'_> for AuthenticationSaslFinal { Ok(Self { verifier }) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn decode() -> Result<(), Error> { + // \0\0\0\x05\xccSZ\x7f + + let v = Authentication::decode(Bytes::from_static(b"\0\0\0\x05\xccSZ\x7f"))?; + + assert!(matches!( + v, + Authentication::Md5Password(AuthenticationMd5Password { + salt: [204, 83, 90, 127], + }) + )); + + Ok(()) + } +} + +#[cfg(all(test, not(debug_assertions)))] +mod bench { + use super::*; + + #[bench] + fn decode(b: &mut test::Bencher) { + use test::black_box; + + let mut buf = Bytes::from_static(b"\0\0\0\x05\xccSZ\x7f"); + + b.iter(|| { + let _ = Authentication::decode(black_box(buf.clone())).unwrap(); + }); + } +} diff --git a/sqlx-postgres/src/codec/backend/backend_key_data.rs b/sqlx-postgres/src/codec/backend/backend_key_data.rs new file mode 100644 index 000000000..1403dd5b1 --- /dev/null +++ b/sqlx-postgres/src/codec/backend/backend_key_data.rs @@ -0,0 +1,40 @@ +use bytes::{Buf, Bytes}; +use sqlx_core::{error::Error, io::Decode}; + +/// Contains cancellation key data. The frontend must save these values if it +/// wishes to be able to issue `CancelRequest` messages later. +#[derive(Debug)] +pub(crate) struct BackendKeyData { + /// The process ID of this database. + pub(crate) process_id: u32, + + /// The secret key of this database. + pub(crate) secret_key: u32, +} + +impl Decode<'_> for BackendKeyData { + fn decode_with(mut buf: Bytes, _: ()) -> Result { + let process_id = buf.get_u32(); + let secret_key = buf.get_u32(); + + Ok(Self { + process_id, + secret_key, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn decode() { + const DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+"; + + let m = BackendKeyData::decode(DATA.into()).unwrap(); + + assert_eq!(m.process_id, 10182); + assert_eq!(m.secret_key, 2303903019); + } +} diff --git a/sqlx-postgres/src/codec/backend/command_complete.rs b/sqlx-postgres/src/codec/backend/command_complete.rs new file mode 100644 index 000000000..f4a625cdd --- /dev/null +++ b/sqlx-postgres/src/codec/backend/command_complete.rs @@ -0,0 +1,85 @@ +use atoi::atoi; +use bytes::Bytes; +use memchr::memrchr; +use sqlx_core::{error::Error, io::Decode}; + +#[derive(Debug)] +pub(crate) struct CommandComplete { + /// The command tag. This is usually a single word that identifies which SQL command + /// was completed. + tag: Bytes, +} + +impl Decode<'_> for CommandComplete { + #[inline] + fn decode_with(buf: Bytes, _: ()) -> Result { + Ok(CommandComplete { tag: buf }) + } +} + +impl CommandComplete { + /// Returns the number of rows affected. + /// If the command does not return rows (e.g., "CREATE TABLE"), returns 0. + pub(crate) fn rows_affected(&self) -> u64 { + // Look backwards for the first SPACE + memrchr(b' ', &self.tag) + // This is either a word or the number of rows affected + .and_then(|i| atoi(&self.tag[(i + 1)..])) + .unwrap_or(0) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn decode_insert() { + const DATA: &[u8] = b"INSERT 0 1214\0"; + + let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); + + assert_eq!(cc.rows_affected(), 1214); + } + + #[test] + fn decode_begin() { + const DATA: &[u8] = b"BEGIN\0"; + + let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); + + assert_eq!(cc.rows_affected(), 0); + } + + #[test] + fn decode_update() { + const DATA: &[u8] = b"UPDATE 5\0"; + + let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); + + assert_eq!(cc.rows_affected(), 5); + } +} + +#[cfg(all(test, not(debug_assertions)))] +mod bench { + #[bench] + fn decode(b: &mut test::Bencher) { + const DATA: &[u8] = b"INSERT 0 1214\0"; + + b.iter(|| { + let _ = CommandComplete::decode(test::black_box(Bytes::from_static(DATA))); + }); + } + + #[bench] + fn rows_affected(b: &mut test::Bencher) { + const DATA: &[u8] = b"INSERT 0 1214\0"; + + let data = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); + + b.iter(|| { + let _rows = test::black_box(&data).rows_affected(); + }); + } +} diff --git a/sqlx-postgres/src/codec/backend/mod.rs b/sqlx-postgres/src/codec/backend/mod.rs new file mode 100644 index 000000000..d820a2d3f --- /dev/null +++ b/sqlx-postgres/src/codec/backend/mod.rs @@ -0,0 +1,88 @@ +use bytes::Bytes; +use sqlx_core::{error::Error, io::Decode}; + +mod authentication; +mod backend_key_data; +mod command_complete; +// mod data_row; +// mod notice; +// mod notification; +// mod parameter_description; +mod ready_for_query; +// mod row_description; +// mod ssl_request; + +pub(crate) use authentication::{Authentication, AuthenticationMd5Password}; +pub(crate) use backend_key_data::BackendKeyData; +pub(crate) use command_complete::CommandComplete; +pub(crate) use ready_for_query::{ReadyForQuery, TransactionStatus}; + +// https://www.postgresql.org/docs/current/protocol-message-formats.html + +#[derive(Debug)] +#[repr(u8)] +pub(crate) enum MessageFormat { + Authentication, + BackendKeyData, + BindComplete, + CloseComplete, + CommandComplete, + DataRow, + EmptyQueryResponse, + ErrorResponse, + NoData, + NoticeResponse, + NotificationResponse, + ParameterDescription, + ParameterStatus, + ParseComplete, + PortalSuspended, + ReadyForQuery, + RowDescription, +} + +#[derive(Debug)] +pub(crate) struct RawMessage { + pub(crate) format: MessageFormat, + pub(crate) contents: Bytes, +} + +impl RawMessage { + #[inline] + pub(crate) fn decode<'de, T>(self) -> Result + where + T: Decode<'de>, + { + T::decode(self.contents) + } +} + +impl MessageFormat { + pub(crate) fn try_from_u8(v: u8) -> Result { + Ok(match v { + b'1' => MessageFormat::ParseComplete, + b'2' => MessageFormat::BindComplete, + b'3' => MessageFormat::CloseComplete, + b'C' => MessageFormat::CommandComplete, + b'D' => MessageFormat::DataRow, + b'E' => MessageFormat::ErrorResponse, + b'I' => MessageFormat::EmptyQueryResponse, + b'A' => MessageFormat::NotificationResponse, + b'K' => MessageFormat::BackendKeyData, + b'N' => MessageFormat::NoticeResponse, + b'R' => MessageFormat::Authentication, + b'S' => MessageFormat::ParameterStatus, + b'T' => MessageFormat::RowDescription, + b'Z' => MessageFormat::ReadyForQuery, + b'n' => MessageFormat::NoData, + b's' => MessageFormat::PortalSuspended, + b't' => MessageFormat::ParameterDescription, + + _ => { + return Err(Error::Protocol( + format!("unknown message type: {:?}", v as char).into(), + )) + } + }) + } +} diff --git a/sqlx-core/src/postgres/message/ready_for_query.rs b/sqlx-postgres/src/codec/backend/ready_for_query.rs similarity index 59% rename from sqlx-core/src/postgres/message/ready_for_query.rs rename to sqlx-postgres/src/codec/backend/ready_for_query.rs index 791d01b20..a3067aac9 100644 --- a/sqlx-core/src/postgres/message/ready_for_query.rs +++ b/sqlx-postgres/src/codec/backend/ready_for_query.rs @@ -1,11 +1,9 @@ use bytes::Bytes; - -use crate::error::Error; -use crate::io::Decode; +use sqlx_core::{error::Error, io::Decode}; #[derive(Debug)] #[repr(u8)] -pub enum TransactionStatus { +pub(crate) enum TransactionStatus { /// Not in a transaction block. Idle = b'I', @@ -17,8 +15,8 @@ pub enum TransactionStatus { } #[derive(Debug)] -pub struct ReadyForQuery { - pub transaction_status: TransactionStatus, +pub(crate) struct ReadyForQuery { + pub(crate) transaction_status: TransactionStatus, } impl Decode<'_> for ReadyForQuery { @@ -29,10 +27,10 @@ impl Decode<'_> for ReadyForQuery { b'E' => TransactionStatus::Error, status => { - return Err(err_protocol!( + return Err(Error::protocol_msg(format!( "unknown transaction status: {:?}", status as char - )); + ))); } }; @@ -42,13 +40,18 @@ impl Decode<'_> for ReadyForQuery { } } -#[test] -fn test_decode_ready_for_query() -> Result<(), Error> { - const DATA: &[u8] = b"E"; +#[cfg(test)] +mod tests { + use super::*; - let m = ReadyForQuery::decode(Bytes::from_static(DATA))?; + #[test] + fn decode() -> Result<(), Error> { + const DATA: &[u8] = b"E"; - assert!(matches!(m.transaction_status, TransactionStatus::Error)); + let m = ReadyForQuery::decode(Bytes::from_static(DATA))?; - Ok(()) + assert!(matches!(m.transaction_status, TransactionStatus::Error)); + + Ok(()) + } } diff --git a/sqlx-postgres/src/codec/frontend/mod.rs b/sqlx-postgres/src/codec/frontend/mod.rs new file mode 100644 index 000000000..51559a2d1 --- /dev/null +++ b/sqlx-postgres/src/codec/frontend/mod.rs @@ -0,0 +1,17 @@ +// mod bind; +// mod close; +// mod describe; +// mod execute; +// mod flush; +// mod parse; +mod password; +mod query; +mod startup; +mod sync; +mod terminate; + +pub(crate) use password::Password; +pub(crate) use query::Query; +pub(crate) use startup::Startup; +pub(crate) use sync::Sync; +pub(crate) use terminate::Terminate; diff --git a/sqlx-core/src/postgres/message/password.rs b/sqlx-postgres/src/codec/frontend/password.rs similarity index 55% rename from sqlx-core/src/postgres/message/password.rs rename to sqlx-postgres/src/codec/frontend/password.rs index 8b0a8d66a..a97b3c195 100644 --- a/sqlx-core/src/postgres/message/password.rs +++ b/sqlx-postgres/src/codec/frontend/password.rs @@ -1,12 +1,10 @@ +use crate::io::{put_length_prefixed, put_str}; +use md5::{Digest, Md5}; +use sqlx_core::{error::Error, io::Encode}; use std::fmt::Write; -use md5::{Digest, Md5}; - -use crate::io::{BufMutExt, Encode}; -use crate::postgres::io::PgBufMutExt; - #[derive(Debug)] -pub enum Password<'a> { +pub(crate) enum Password<'a> { Cleartext(&'a str), Md5 { @@ -27,14 +25,14 @@ impl Password<'_> { } impl Encode<'_> for Password<'_> { - fn encode_with(&self, buf: &mut Vec, _: ()) { + fn encode_with(&self, buf: &mut Vec, _: ()) -> Result<(), Error> { buf.reserve(1 + 4 + self.len()); buf.push(b'p'); - buf.put_length_prefixed(|buf| { + put_length_prefixed(buf, |buf| { match self { Password::Cleartext(password) => { - buf.put_str_nul(password); + put_str(buf, password); } Password::Md5 { @@ -63,70 +61,68 @@ impl Encode<'_> for Password<'_> { let _ = write!(output, "md5{:x}", hasher.finalize()); - buf.put_str_nul(&output); + put_str(buf, &output); } } - }); + + Ok(()) + }) } } -#[test] -fn test_encode_clear_password() { - const EXPECTED: &[u8] = b"p\0\0\0\rpassword\0"; +#[cfg(test)] +mod tests { + use super::*; - let mut buf = Vec::new(); - let m = Password::Cleartext("password"); + #[test] + fn encode_md5() { + const EXPECTED: &[u8] = b"p\0\0\0(md53e2c9d99d49b201ef867a36f3f9ed62c\0"; - m.encode(&mut buf); - - assert_eq!(buf, EXPECTED); -} - -#[test] -fn test_encode_md5_password() { - const EXPECTED: &[u8] = b"p\0\0\0(md53e2c9d99d49b201ef867a36f3f9ed62c\0"; - - let mut buf = Vec::new(); - let m = Password::Md5 { - password: "password", - username: "root", - salt: [147, 24, 57, 152], - }; - - m.encode(&mut buf); - - assert_eq!(buf, EXPECTED); -} - -#[cfg(all(test, not(debug_assertions)))] -#[bench] -fn bench_encode_clear_password(b: &mut test::Bencher) { - use test::black_box; - - let mut buf = Vec::with_capacity(128); - - b.iter(|| { - buf.clear(); - - black_box(Password::Cleartext("password")).encode(&mut buf); - }); -} - -#[cfg(all(test, not(debug_assertions)))] -#[bench] -fn bench_encode_md5_password(b: &mut test::Bencher) { - use test::black_box; - - let mut buf = Vec::with_capacity(128); - - b.iter(|| { - buf.clear(); - - black_box(Password::Md5 { + let mut buf = Vec::new(); + let m = Password::Md5 { password: "password", username: "root", salt: [147, 24, 57, 152], - }) - .encode(&mut buf); - }); + }; + + m.encode(&mut buf); + + assert_eq!(buf, EXPECTED); + } +} + +#[cfg(all(test, not(debug_assertions)))] +mod bench { + use super::*; + + #[bench] + fn encode_clear(b: &mut test::Bencher) { + use test::black_box; + + let mut buf = Vec::with_capacity(128); + + b.iter(|| { + buf.clear(); + + black_box(Password::Cleartext("password")).encode(&mut buf); + }); + } + + #[bench] + fn encode_md5(b: &mut test::Bencher) { + use test::black_box; + + let mut buf = Vec::with_capacity(128); + + b.iter(|| { + buf.clear(); + + black_box(Password::Md5 { + password: "password", + username: "root", + salt: [147, 24, 57, 152], + }) + .encode(&mut buf); + }); + } } diff --git a/sqlx-postgres/src/codec/frontend/query.rs b/sqlx-postgres/src/codec/frontend/query.rs new file mode 100644 index 000000000..ca4d05ebf --- /dev/null +++ b/sqlx-postgres/src/codec/frontend/query.rs @@ -0,0 +1,43 @@ +use crate::io::put_str; +use sqlx_core::error::Error; +use sqlx_core::io::Encode; + +/// A simple query cycle is initiated by the frontend sending a `Query` message to the backend. +/// The message includes an SQL command (or commands) expressed as a text string. +#[derive(Debug)] +pub(crate) struct Query<'a>(pub(crate) &'a str); + +impl Encode<'_> for Query<'_> { + fn encode_with(&self, buf: &mut Vec, _: ()) -> Result<(), Error> { + let len = 4 + self.0.len() + 1; + + if len + 1 > i32::MAX as usize { + return Err(Error::Query( + "SQL query string is too large to transmit".into(), + )); + } + + buf.reserve(len + 1); + buf.push(b'Q'); + buf.extend(&(len as i32).to_be_bytes()); + put_str(buf, self.0); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn encode() { + const EXPECTED: &[u8] = b"Q\0\0\0\rSELECT 1\0"; + + let mut buf = Vec::new(); + + Query("SELECT 1").encode(&mut buf); + + assert_eq!(buf, EXPECTED); + } +} diff --git a/sqlx-postgres/src/codec/frontend/startup.rs b/sqlx-postgres/src/codec/frontend/startup.rs new file mode 100644 index 000000000..a4436e6df --- /dev/null +++ b/sqlx-postgres/src/codec/frontend/startup.rs @@ -0,0 +1,87 @@ +use crate::io::{put_length_prefixed, put_str}; +use sqlx_core::{error::Error, io::Encode}; + +// To begin a session, a frontend opens a connection to the server and sends a startup message. +// This message includes the names of the user and of the database the user wants to connect to; +// it also identifies the particular protocol version to be used. + +// Optionally, the startup message can include additional settings for run-time parameters. + +pub(crate) struct Startup<'a>(pub(crate) &'a [(&'a str, Option<&'a str>)]); + +impl Encode<'_> for Startup<'_> { + fn encode_with(&self, buf: &mut Vec, _: ()) -> Result<(), Error> { + put_length_prefixed(buf, |buf| { + // The protocol version number. + // + // The most significant 16 bits are the major version + // number (3 for the protocol described here). + // + // The least significant 16 bits are the minor version + // number (0 for the protocol described here). + buf.extend_from_slice(&0x0003_0000_i32.to_be_bytes()); + + for (name, value) in self.0 { + if let Some(value) = value { + put_startup_parameter(buf, name, value); + } + } + + // A zero byte is required as a terminator + // after the last name/value pair. + buf.push(0); + + Ok(()) + }) + } +} + +fn put_startup_parameter(buf: &mut Vec, name: &str, value: &str) { + put_str(buf, name); + put_str(buf, value); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn encode() { + const EXPECTED: &[u8] = b"\0\0\0)\0\x03\0\0user\0postgres\0database\0postgres\0\0"; + + let mut buf = Vec::new(); + + let m = Startup(&[("user", Some("postgres")), ("database", Some("postgres"))]); + + m.encode(&mut buf); + + assert_eq!(buf, EXPECTED); + } +} + +#[cfg(all(test, not(debug_assertions)))] +mod bench { + use super::*; + + #[bench] + fn encode(b: &mut test::Bencher) { + use test::black_box; + + let mut buf = Vec::with_capacity(1024); + + b.iter(|| { + buf.clear(); + + let m = (Startup(&[ + ("user", "postgres"), + ("database", "postgres"), + ("DateStyle", "ISO, MDY"), + ("client_encoding", "UTF8"), + ("TimeZone", "UTC"), + ("extra_float_digits", "3"), + ])); + + m.encode(&mut buf); + }); + } +} diff --git a/sqlx-postgres/src/codec/frontend/sync.rs b/sqlx-postgres/src/codec/frontend/sync.rs new file mode 100644 index 000000000..4279b6dd5 --- /dev/null +++ b/sqlx-postgres/src/codec/frontend/sync.rs @@ -0,0 +1,24 @@ +use sqlx_core::error::Error; +use sqlx_core::io::Encode; + +/// At completion of each series of extended-query messages, the frontend should issue a +/// `Sync` message. +/// +/// This parameterless message causes the backend to close the current transaction if +/// it's not inside a `BEGIN` / `COMMIT` transaction block (“close” meaning to commit +/// if no error, or roll back if error). Then a `ReadyForQuery` response is issued. +/// +/// The purpose of Sync is to provide a resynchronization point for error recovery. +/// +#[derive(Debug)] +pub(crate) struct Sync; + +impl Encode<'_> for Sync { + fn encode_with(&self, buf: &mut Vec, _: ()) -> Result<(), Error> { + buf.reserve(5); + buf.push(b'S'); + buf.extend(&4_i32.to_be_bytes()); + + Ok(()) + } +} diff --git a/sqlx-postgres/src/codec/frontend/terminate.rs b/sqlx-postgres/src/codec/frontend/terminate.rs new file mode 100644 index 000000000..bb507c8e5 --- /dev/null +++ b/sqlx-postgres/src/codec/frontend/terminate.rs @@ -0,0 +1,21 @@ +use sqlx_core::error::Error; +use sqlx_core::io::Encode; + +/// The normal, graceful termination procedure is that the frontend +/// sends a Terminate message and immediately closes the connection. +/// +/// On receipt of this message, the backend closes the connection +/// and terminates. +/// +#[derive(Debug)] +pub(crate) struct Terminate; + +impl Encode<'_> for Terminate { + fn encode_with(&self, buf: &mut Vec, _: ()) -> Result<(), Error> { + buf.reserve(5); + buf.push(b'X'); + buf.extend(&4_u32.to_be_bytes()); + + Ok(()) + } +} diff --git a/sqlx-postgres/src/codec/mod.rs b/sqlx-postgres/src/codec/mod.rs new file mode 100644 index 000000000..9f6aebb7b --- /dev/null +++ b/sqlx-postgres/src/codec/mod.rs @@ -0,0 +1,2 @@ +pub(crate) mod backend; +pub(crate) mod frontend; diff --git a/sqlx-postgres/src/connection/connect.rs b/sqlx-postgres/src/connection/connect.rs new file mode 100644 index 000000000..6048ffa9e --- /dev/null +++ b/sqlx-postgres/src/connection/connect.rs @@ -0,0 +1,120 @@ +use crate::codec::backend::{Authentication, BackendKeyData, MessageFormat, ReadyForQuery}; +use crate::codec::frontend; +use crate::{PgConnectOptions, PgConnection}; +use sqlx_core::{error::Error, io::BufStream}; +use sqlx_rt::TcpStream; + +impl PgConnection { + pub(crate) async fn connect(options: &PgConnectOptions) -> Result { + let stream = TcpStream::connect((&*options.host, options.port)).await?; + + // Set TCP_NODELAY to disable the Nagle algorithm + // We are telling the kernel that we bundle data to be sent in large write() calls + // instead of sending many small packets. + stream.set_nodelay(true)?; + + // TODO: Upgrade to TLS if asked + + let mut stream = BufStream::with_capacity(stream, 1024, 1024); + + // To begin a session, a frontend opens a connection to the server + // and sends a startup message. + + stream.write(frontend::Startup(&[ + ("user", Some(&options.username)), + ("database", options.database.as_deref()), + // Sets the display format for date and time values, + // as well as the rules for interpreting ambiguous date input values. + ("DateStyle", Some("ISO, MDY")), + // + // Sets the client-side encoding (character set). + // + ("client_encoding", Some("UTF8")), + // + // Sets the time zone for displaying and interpreting time stamps. + ("TimeZone", Some("UTC")), + // + // Adjust postgres to return (more) precise values for floats + // NOTE: This is default in postgres 12+ + ("extra_float_digits", Some("3")), + ]))?; + + // Wrap our network in the connection type with default values for its properties + // This lets us access methods on self + + let mut conn = Self::new(stream); + + // The server then uses this information and the contents of + // its configuration files (such as pg_hba.conf) to determine whether the connection is + // provisionally acceptable, and what additional + // authentication is required (if any). + + loop { + let message = conn.recv().await?; + + match message.format { + MessageFormat::Authentication => match message.decode()? { + Authentication::Ok => { + // the authentication exchange is successfully completed + // do nothing; no more information is required to continue + } + + Authentication::CleartextPassword => { + // The frontend must now send a [PasswordMessage] containing the + // password in clear-text form. + + conn.stream.write(frontend::Password::Cleartext( + options.password.as_deref().unwrap_or_default(), + ))?; + } + + Authentication::Md5Password(body) => { + // The frontend must now send a [PasswordMessage] containing the + // password (with user name) encrypted via MD5, then encrypted again + // using the 4-byte random salt specified in the + // [AuthenticationMD5Password] message. + + conn.stream.write(frontend::Password::Md5 { + username: &options.username, + password: options.password.as_deref().unwrap_or_default(), + salt: body.salt, + })?; + } + + // Authentication::Sasl(body) => { + // // sasl::authenticate(&mut stream, options, body).await?; + // todo!("sasl") + // } + method => { + return Err(Error::protocol_msg(format!( + "unsupported authentication method: {:?}", + method + ))); + } + }, + + MessageFormat::BackendKeyData => { + // provides secret-key data that the frontend must save if it wants to be + // able to issue cancel requests later + + let data: BackendKeyData = message.decode()?; + + conn.process_id = data.process_id; + conn.secret_key = data.secret_key; + } + + MessageFormat::ReadyForQuery => { + conn.transaction_status = message.decode::()?.transaction_status; + + // start-up is completed. + // the frontend can now issue commands. + break; + } + + _ => {} + } + } + + Ok(conn) + } +} diff --git a/sqlx-postgres/src/connection/io.rs b/sqlx-postgres/src/connection/io.rs new file mode 100644 index 000000000..d71b3e983 --- /dev/null +++ b/sqlx-postgres/src/connection/io.rs @@ -0,0 +1,61 @@ +use crate::codec::backend::{MessageFormat, RawMessage}; +use crate::PgConnection; +use bytes::{Buf, Bytes}; +use sqlx_core::error::Error; + +impl PgConnection { + /// Wait for the next message from the database server. + /// Handles standard and asynchronous messages. + pub(crate) async fn recv(&mut self) -> Result { + loop { + let message = self.recv_unchecked().await?; + + match message.format { + MessageFormat::ErrorResponse => { + // an error was returned from the database + todo!("errors: {:?}", message.contents) + } + + MessageFormat::NotificationResponse => { + // a notification was received; this connection has had `LISTEN` ran on it + todo!("notifications"); + continue; + } + + MessageFormat::ParameterStatus => { + // informs the frontend about the current + // setting of backend parameters + + // we currently have no use for that data so we ignore this message + continue; + } + + _ => {} + } + + return Ok(message); + } + } + + /// Wait for the next message from the database server. + pub(crate) async fn recv_unchecked(&mut self) -> Result { + let mut header = self.stream.peek(0, 5).await?; + // if the future for this method is dropped now, we will re-peek the same header + + // the first byte of a message identifies the message type + let kind = header.get_u8(); + + // and the next four bytes specify the length of the rest of the message ( + // this length count includes itself, but not the message-type byte). + let length = header.get_i32() as usize - 4; + + let contents = self.stream.read(5, length).await?; + // now the packet is fully consumed from the stream and when this method is called + // again, it will get the *next* message + + Ok(RawMessage { + format: MessageFormat::try_from_u8(kind)?, + contents, + }) + } +} diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs new file mode 100644 index 000000000..1516405f2 --- /dev/null +++ b/sqlx-postgres/src/connection/mod.rs @@ -0,0 +1,67 @@ +use crate::codec::backend::TransactionStatus; +use crate::{PgConnectOptions, Postgres}; +use futures_core::future::BoxFuture; +use sqlx_core::connection::Connection; +use sqlx_core::error::Error; +use sqlx_core::io::BufStream; +use sqlx_rt::TcpStream; + +mod connect; +mod io; + +/// A connection to a PostgreSQL database. +pub struct PgConnection { + // underlying TCP or UDS stream, + // wrapped in a potentially TLS stream, + // wrapped in a buffered stream + stream: BufStream, + + // process id of this backend + // used to send cancel requests + #[allow(dead_code)] + process_id: u32, + + // secret key of this backend + // used to send cancel requests + #[allow(dead_code)] + secret_key: u32, + + // status of the connection + // are we in a transaction? + transaction_status: TransactionStatus, +} + +impl PgConnection { + pub(crate) const fn new(stream: BufStream) -> Self { + Self { + stream, + process_id: 0, + secret_key: 0, + transaction_status: TransactionStatus::Idle, + } + } +} + +impl Connection for PgConnection { + type Database = Postgres; + + type Options = PgConnectOptions; + + fn close(self) -> BoxFuture<'static, Result<(), Error>> { + unimplemented!() + } + + fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> { + unimplemented!() + } + + #[doc(hidden)] + fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> { + unimplemented!() + } + + #[doc(hidden)] + fn should_flush(&self) -> bool { + unimplemented!() + } +} diff --git a/sqlx-postgres/src/database.rs b/sqlx-postgres/src/database.rs new file mode 100644 index 000000000..46f0bdc11 --- /dev/null +++ b/sqlx-postgres/src/database.rs @@ -0,0 +1,44 @@ +use crate::PgConnection; +use sqlx_core::database::{Database, HasStatementCache}; + +/// PostgreSQL database driver. +#[derive(Debug)] +pub struct Postgres; + +impl Database for Postgres { + type Connection = PgConnection; + + // type TransactionManager = PgTransactionManager; + // + // type Row = PgRow; + // + // type Done = PgDone; + // + // type Column = PgColumn; + // + // type TypeInfo = PgTypeInfo; + // + // type Value = PgValue; +} + +// impl<'r> HasValueRef<'r> for Postgres { +// type Database = Postgres; +// +// type ValueRef = PgValueRef<'r>; +// } +// +// impl HasArguments<'_> for Postgres { +// type Database = Postgres; +// +// type Arguments = PgArguments; +// +// type ArgumentBuffer = PgArgumentBuffer; +// } +// +// impl<'q> HasStatement<'q> for Postgres { +// type Database = Postgres; +// +// type Statement = PgStatement<'q>; +// } + +impl HasStatementCache for Postgres {} diff --git a/sqlx-postgres/src/io.rs b/sqlx-postgres/src/io.rs new file mode 100644 index 000000000..37a21ae83 --- /dev/null +++ b/sqlx-postgres/src/io.rs @@ -0,0 +1,22 @@ +use sqlx_core::error::Error; + +pub(crate) fn put_length_prefixed( + buf: &mut Vec, + f: impl FnOnce(&mut Vec) -> Result, +) -> Result { + let offset = buf.len(); + buf.resize(offset + 4, 0); + + let r = f(buf)?; + + let len = (buf.len() - offset) as i32; + (&mut buf[offset..offset + 4]).copy_from_slice(&len.to_be_bytes()); + + Ok(r) +} + +#[inline] +pub(crate) fn put_str(buf: &mut Vec, s: &str) { + buf.extend_from_slice(s.as_bytes()); + buf.push(b'\0'); +} diff --git a/sqlx-postgres/src/lib.rs b/sqlx-postgres/src/lib.rs new file mode 100644 index 000000000..e59999f77 --- /dev/null +++ b/sqlx-postgres/src/lib.rs @@ -0,0 +1,21 @@ +//! **PostgreSQL** database driver. +//! +#![forbid(unsafe_code)] +#![warn( + future_incompatible, + rust_2018_idioms, + missing_docs, + missing_doc_code_examples, + unreachable_pub +)] +#![allow(unused)] + +mod codec; +mod connection; +mod database; +mod io; +mod options; + +pub use connection::PgConnection; +pub use database::Postgres; +pub use options::{PgConnectOptions, PgSslMode}; diff --git a/sqlx-postgres/src/options/connect.rs b/sqlx-postgres/src/options/connect.rs new file mode 100644 index 000000000..06747be99 --- /dev/null +++ b/sqlx-postgres/src/options/connect.rs @@ -0,0 +1,13 @@ +use crate::{PgConnectOptions, PgConnection}; +use futures_core::future::BoxFuture; +use sqlx_core::error::Error; +use sqlx_core::options::ConnectOptions; + +impl ConnectOptions for PgConnectOptions { + type Connection = PgConnection; + + #[inline] + fn connect(&self) -> BoxFuture<'_, Result> { + Box::pin(PgConnection::connect(self)) + } +} diff --git a/sqlx-core/src/postgres/options/mod.rs b/sqlx-postgres/src/options/mod.rs similarity index 93% rename from sqlx-core/src/postgres/options/mod.rs rename to sqlx-postgres/src/options/mod.rs index 65ed73ef7..a985ed579 100644 --- a/sqlx-core/src/postgres/options/mod.rs +++ b/sqlx-postgres/src/options/mod.rs @@ -5,6 +5,8 @@ mod connect; mod parse; mod ssl_mode; +// false positive +#[allow(unreachable_pub)] pub use ssl_mode::PgSslMode; /// Options and flags which can be used to configure a PostgreSQL connection. @@ -43,8 +45,9 @@ pub use ssl_mode::PgSslMode; /// /// ```rust,no_run /// # use sqlx_core::error::Error; -/// # use sqlx_core::connection::{Connection, ConnectOptions}; -/// # use sqlx_core::postgres::{PgConnectOptions, PgConnection, PgSslMode}; +/// # use sqlx_core::connection::Connection; +/// # use sqlx_core::options::ConnectOptions; +/// # use sqlx_postgres::{PgConnectOptions, PgConnection, PgSslMode}; /// # /// # fn main() { /// # #[cfg(feature = "runtime-async-std")] @@ -102,7 +105,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_core::postgres::PgConnectOptions; + /// # use sqlx_postgres::PgConnectOptions; /// let options = PgConnectOptions::new(); /// ``` pub fn new() -> Self { @@ -141,7 +144,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_core::postgres::PgConnectOptions; + /// # use sqlx_postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .host("localhost"); /// ``` @@ -157,7 +160,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_core::postgres::PgConnectOptions; + /// # use sqlx_postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .port(5432); /// ``` @@ -183,7 +186,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_core::postgres::PgConnectOptions; + /// # use sqlx_postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .username("postgres"); /// ``` @@ -197,7 +200,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_core::postgres::PgConnectOptions; + /// # use sqlx_postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .username("root") /// .password("safe-and-secure"); @@ -212,7 +215,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_core::postgres::PgConnectOptions; + /// # use sqlx_postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .database("postgres"); /// ``` @@ -232,7 +235,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_core::postgres::{PgSslMode, PgConnectOptions}; + /// # use sqlx_postgres::{PgSslMode, PgConnectOptions}; /// let options = PgConnectOptions::new() /// .ssl_mode(PgSslMode::Require); /// ``` @@ -248,7 +251,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_core::postgres::{PgSslMode, PgConnectOptions}; + /// # use sqlx_postgres::{PgSslMode, PgConnectOptions}; /// let options = PgConnectOptions::new() /// // Providing a CA certificate with less than VerifyCa is pointless /// .ssl_mode(PgSslMode::VerifyCa) diff --git a/sqlx-core/src/postgres/options/parse.rs b/sqlx-postgres/src/options/parse.rs similarity index 79% rename from sqlx-core/src/postgres/options/parse.rs rename to sqlx-postgres/src/options/parse.rs index e5dd27fde..a55888c7e 100644 --- a/sqlx-core/src/postgres/options/parse.rs +++ b/sqlx-postgres/src/options/parse.rs @@ -1,5 +1,5 @@ -use crate::error::Error; -use crate::postgres::PgConnectOptions; +use crate::PgConnectOptions; +use sqlx_core::error::Error; use std::str::FromStr; use url::Url; @@ -7,7 +7,14 @@ impl FromStr for PgConnectOptions { type Err = Error; fn from_str(s: &str) -> Result { - let url: Url = s.parse().map_err(Error::config)?; + let url: Url = s.parse().map_err(Error::configuration)?; + + if !matches!(url.scheme(), "postgres" | "postgresql") { + return Err(Error::configuration_msg(format!( + "unsupported URI scheme {:?} for PostgreSQL", + url.scheme() + ))); + } let mut options = Self::default(); @@ -36,7 +43,7 @@ impl FromStr for PgConnectOptions { for (key, value) in url.query_pairs().into_iter() { match &*key { "sslmode" | "ssl-mode" => { - options = options.ssl_mode(value.parse().map_err(Error::config)?); + options = options.ssl_mode(value.parse().map_err(Error::configuration)?); } "sslrootcert" | "ssl-root-cert" | "ssl-ca" => { @@ -44,8 +51,8 @@ impl FromStr for PgConnectOptions { } "statement-cache-capacity" => { - options = - options.statement_cache_capacity(value.parse().map_err(Error::config)?); + options = options + .statement_cache_capacity(value.parse().map_err(Error::configuration)?); } "host" => { diff --git a/sqlx-core/src/postgres/options/ssl_mode.rs b/sqlx-postgres/src/options/ssl_mode.rs similarity index 98% rename from sqlx-core/src/postgres/options/ssl_mode.rs rename to sqlx-postgres/src/options/ssl_mode.rs index fe2a9b614..2d3616c81 100644 --- a/sqlx-core/src/postgres/options/ssl_mode.rs +++ b/sqlx-postgres/src/options/ssl_mode.rs @@ -1,4 +1,4 @@ -use crate::error::Error; +use sqlx_core::error::Error; use std::str::FromStr; /// Options for controlling the level of protection provided for PostgreSQL SSL connections.