From a374c18a18bb74215a01b939d0c93962e2122589 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Wed, 19 Feb 2020 08:10:27 -0800 Subject: [PATCH] postgres: rewrite protocol in more iterative and lazy fashion --- Cargo.lock | 9 + Cargo.toml | 1 + examples/postgres/basic/Cargo.toml | 10 + examples/postgres/basic/src/main.rs | 25 + examples/queries/account-by-id.sql | 2 - sqlx-core/src/arguments.rs | 4 +- sqlx-core/src/connection.rs | 143 +++++- sqlx-core/src/cursor.rs | 60 ++- sqlx-core/src/database.rs | 12 +- sqlx-core/src/decode.rs | 5 +- sqlx-core/src/encode.rs | 8 +- sqlx-core/src/executor.rs | 16 +- sqlx-core/src/io/buf_stream.rs | 14 +- sqlx-core/src/lib.rs | 8 +- sqlx-core/src/mysql/arguments.rs | 6 +- sqlx-core/src/mysql/row.rs | 8 +- sqlx-core/src/mysql/types/bool.rs | 4 +- sqlx-core/src/mysql/types/bytes.rs | 8 +- sqlx-core/src/mysql/types/chrono.rs | 10 +- sqlx-core/src/mysql/types/float.rs | 6 +- sqlx-core/src/mysql/types/int.rs | 10 +- sqlx-core/src/mysql/types/str.rs | 8 +- sqlx-core/src/mysql/types/uint.rs | 10 +- sqlx-core/src/pool/conn.rs | 20 +- sqlx-core/src/pool/executor.rs | 91 ++-- sqlx-core/src/pool/mod.rs | 6 +- sqlx-core/src/pool/options.rs | 13 +- sqlx-core/src/postgres/arguments.rs | 7 +- sqlx-core/src/postgres/connection.rs | 464 +++++++++--------- sqlx-core/src/postgres/cursor.rs | 339 +++++++++++-- sqlx-core/src/postgres/database.rs | 8 +- sqlx-core/src/postgres/error.rs | 2 +- sqlx-core/src/postgres/executor.rs | 90 ++-- sqlx-core/src/postgres/mod.rs | 3 +- .../src/postgres/protocol/authentication.rs | 266 +++++----- .../src/postgres/protocol/command_complete.rs | 20 +- sqlx-core/src/postgres/protocol/data_row.rs | 80 ++- sqlx-core/src/postgres/protocol/message.rs | 57 ++- sqlx-core/src/postgres/protocol/mod.rs | 5 +- .../protocol/parameter_description.rs | 12 +- sqlx-core/src/postgres/protocol/response.rs | 4 +- .../src/postgres/protocol/row_description.rs | 12 +- sqlx-core/src/postgres/row.rs | 64 +-- sqlx-core/src/postgres/sasl.rs | 29 +- sqlx-core/src/postgres/stream.rs | 90 ++++ sqlx-core/src/postgres/types/bool.rs | 6 +- sqlx-core/src/postgres/types/bytes.rs | 10 +- sqlx-core/src/postgres/types/chrono.rs | 18 +- sqlx-core/src/postgres/types/float.rs | 10 +- sqlx-core/src/postgres/types/int.rs | 14 +- sqlx-core/src/postgres/types/str.rs | 10 +- sqlx-core/src/postgres/types/uuid.rs | 6 +- sqlx-core/src/query.rs | 47 +- sqlx-core/src/row.rs | 43 +- sqlx-core/src/transaction.rs | 43 +- sqlx-core/src/types.rs | 25 +- sqlx-macros/src/database/mod.rs | 4 +- src/lib.rs | 2 +- tests/postgres-types.rs | 118 ++--- tests/postgres.rs | 92 ++-- 60 files changed, 1586 insertions(+), 931 deletions(-) create mode 100644 examples/postgres/basic/Cargo.toml create mode 100644 examples/postgres/basic/src/main.rs delete mode 100644 examples/queries/account-by-id.sql create mode 100644 sqlx-core/src/postgres/stream.rs diff --git a/Cargo.lock b/Cargo.lock index 83cf1a52..017f02f1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1526,6 +1526,15 @@ dependencies = [ "uuid 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "sqlx-example-postgres-basic" +version = "0.1.0" +dependencies = [ + "anyhow 1.0.26 (registry+https://github.com/rust-lang/crates.io-index)", + "async-std 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", + "sqlx 0.2.5", +] + [[package]] name = "sqlx-example-realworld-postgres" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 885e8832..7f6cdc9b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ members = [ ".", "sqlx-core", "sqlx-macros", + "examples/postgres/basic", "examples/realworld-postgres", "examples/todos-postgres", ] diff --git a/examples/postgres/basic/Cargo.toml b/examples/postgres/basic/Cargo.toml new file mode 100644 index 00000000..6a207e14 --- /dev/null +++ b/examples/postgres/basic/Cargo.toml @@ -0,0 +1,10 @@ +[package] +workspace = "../../.." +name = "sqlx-example-postgres-basic" +version = "0.1.0" +edition = "2018" + +[dependencies] +async-std = { version = "1", features = [ "attributes" ] } +anyhow = "1" +sqlx = { path = "../../..", features = [ "postgres" ] } diff --git a/examples/postgres/basic/src/main.rs b/examples/postgres/basic/src/main.rs new file mode 100644 index 00000000..5747c9cb --- /dev/null +++ b/examples/postgres/basic/src/main.rs @@ -0,0 +1,25 @@ +use sqlx::{Connect, Connection, Cursor, Executor, PgConnection, Row}; +use std::convert::TryInto; +use std::time::Instant; + +#[async_std::main] +async fn main() -> anyhow::Result<()> { + let mut conn = PgConnection::connect("postgres://").await?; + + let mut rows = sqlx::query("SELECT definition FROM pg_database") + .execute(&mut conn) + .await?; + + // let start = Instant::now(); + // while let Some(row) = cursor.next().await? { + // // let raw = row.try_get(0)?.unwrap(); + // + // // println!("hai: {:?}", raw); + // } + + println!("?? = {}", rows); + + // conn.close().await?; + + Ok(()) +} diff --git a/examples/queries/account-by-id.sql b/examples/queries/account-by-id.sql deleted file mode 100644 index 98623eb8..00000000 --- a/examples/queries/account-by-id.sql +++ /dev/null @@ -1,2 +0,0 @@ -select * from (select (1) as id, 'Herp Derpinson' as name) accounts -where id = ? diff --git a/sqlx-core/src/arguments.rs b/sqlx-core/src/arguments.rs index 6519aba3..34c3403d 100644 --- a/sqlx-core/src/arguments.rs +++ b/sqlx-core/src/arguments.rs @@ -2,7 +2,7 @@ use crate::database::Database; use crate::encode::Encode; -use crate::types::HasSqlType; +use crate::types::Type; /// A tuple of arguments to be sent to the database. pub trait Arguments: Send + Sized + Default + 'static { @@ -15,7 +15,7 @@ pub trait Arguments: Send + Sized + Default + 'static { /// Add the value to the end of the arguments. fn add(&mut self, value: T) where - Self::Database: HasSqlType, + T: Type, T: Encode; } diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index e1249d0d..92bc13da 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -1,10 +1,14 @@ +use std::convert::TryInto; +use std::ops::{Deref, DerefMut}; + +use futures_core::future::BoxFuture; +use futures_util::TryFutureExt; + use crate::database::Database; use crate::describe::Describe; use crate::executor::Executor; +use crate::pool::{Pool, PoolConnection}; use crate::url::Url; -use futures_core::future::BoxFuture; -use futures_util::TryFutureExt; -use std::convert::TryInto; /// Represents a single database connection rather than a pool of database connections. /// @@ -20,20 +24,13 @@ where fn close(self) -> BoxFuture<'static, crate::Result<()>>; /// Verifies a connection to the database is still alive. - fn ping(&mut self) -> BoxFuture> - where - for<'a> &'a mut Self: Executor<'a>, - { - Box::pin((&mut *self).execute("SELECT 1").map_ok(|_| ())) - } + fn ping(&mut self) -> BoxFuture>; #[doc(hidden)] fn describe<'e, 'q: 'e>( &'e mut self, query: &'q str, - ) -> BoxFuture<'e, crate::Result>> { - todo!("make this a required function"); - } + ) -> BoxFuture<'e, crate::Result>>; } /// Represents a type that can directly establish a new connection. @@ -44,3 +41,125 @@ pub trait Connect: Connection { T: TryInto, Self: Sized; } + +mod internal { + pub enum MaybeOwnedConnection<'c, C> + where + C: super::Connect, + { + Borrowed(&'c mut C), + Owned(super::PoolConnection), + } + + pub enum ConnectionSource<'c, C> + where + C: super::Connect, + { + Empty, + Connection(MaybeOwnedConnection<'c, C>), + Pool(super::Pool), + } +} + +pub(crate) use self::internal::{ConnectionSource, MaybeOwnedConnection}; + +impl<'c, C> MaybeOwnedConnection<'c, C> +where + C: Connect, +{ + pub(crate) fn borrow(&mut self) -> MaybeOwnedConnection<'_, C> { + match self { + MaybeOwnedConnection::Borrowed(conn) => MaybeOwnedConnection::Borrowed(&mut *conn), + MaybeOwnedConnection::Owned(ref mut conn) => MaybeOwnedConnection::Borrowed(conn), + } + } +} + +impl<'c, C, DB> ConnectionSource<'c, C> +where + C: Connect, + DB: Database, +{ + pub(crate) async fn resolve_by_ref(&mut self) -> crate::Result> { + if let ConnectionSource::Pool(pool) = self { + *self = + ConnectionSource::Connection(MaybeOwnedConnection::Owned(pool.acquire().await?)); + } + + Ok(match self { + ConnectionSource::Empty => panic!("`PgCursor` must not be used after being polled"), + ConnectionSource::Connection(conn) => conn.borrow(), + ConnectionSource::Pool(_) => unreachable!(), + }) + } + + pub(crate) async fn resolve(mut self) -> crate::Result> { + if let ConnectionSource::Pool(pool) = self { + self = ConnectionSource::Connection(MaybeOwnedConnection::Owned(pool.acquire().await?)); + } + + Ok(self.into_connection()) + } + + pub(crate) fn into_connection(self) -> MaybeOwnedConnection<'c, C> { + match self { + ConnectionSource::Connection(conn) => conn, + ConnectionSource::Empty | ConnectionSource::Pool(_) => { + panic!("`PgCursor` must not be used after being polled"); + } + } + } +} + +impl Default for ConnectionSource<'_, C> +where + C: Connect, +{ + fn default() -> Self { + ConnectionSource::Empty + } +} + +impl<'c, C> From<&'c mut C> for MaybeOwnedConnection<'c, C> +where + C: Connect, +{ + fn from(conn: &'c mut C) -> Self { + MaybeOwnedConnection::Borrowed(conn) + } +} + +impl<'c, C> From> for MaybeOwnedConnection<'c, C> +where + C: Connect, +{ + fn from(conn: PoolConnection) -> Self { + MaybeOwnedConnection::Owned(conn) + } +} + +impl<'c, C> Deref for MaybeOwnedConnection<'c, C> +where + C: Connect, +{ + type Target = C; + + fn deref(&self) -> &Self::Target { + match self { + MaybeOwnedConnection::Borrowed(conn) => conn, + MaybeOwnedConnection::Owned(conn) => conn, + } + } +} + +impl<'c, C> DerefMut for MaybeOwnedConnection<'c, C> +where + C: Connect, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + match self { + MaybeOwnedConnection::Borrowed(conn) => conn, + MaybeOwnedConnection::Owned(conn) => conn, + } + } +} diff --git a/sqlx-core/src/cursor.rs b/sqlx-core/src/cursor.rs index 048d41ff..626ff45f 100644 --- a/sqlx-core/src/cursor.rs +++ b/sqlx-core/src/cursor.rs @@ -3,7 +3,10 @@ use std::future::Future; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; +use crate::connection::MaybeOwnedConnection; use crate::database::{Database, HasRow}; +use crate::executor::Execute; +use crate::{Connect, Pool}; /// Represents a result set, which is generated by executing a query against the database. /// @@ -13,7 +16,7 @@ use crate::database::{Database, HasRow}; /// Initially the `Cursor` is positioned before the first row. The `next` method moves the cursor /// to the next row, and because it returns `None` when there are no more rows, it can be used /// in a `while` loop to iterate through all returned rows. -pub trait Cursor<'a> +pub trait Cursor<'c, 'q> where Self: Send, // `.await`-ing a cursor will return the affected rows from the query @@ -21,16 +24,59 @@ where { type Database: Database; - /// Fetch the first row in the result. Returns `None` if no row is present. - /// - /// Returns `Error::MoreThanOneRow` if more than one row is in the result. - fn first(self) -> BoxFuture<'a, crate::Result::Row>>>; + // Construct the [Cursor] from a [Pool] + // Meant for internal use only + // TODO: Anyone have any better ideas on how to instantiate cursors generically from a pool? + #[doc(hidden)] + fn from_pool(pool: &Pool<::Connection>, query: E) -> Self + where + Self: Sized, + E: Execute<'q, Self::Database>; + + #[doc(hidden)] + fn from_connection(conn: C, query: E) -> Self + where + Self: Sized, + ::Connection: Connect, + // MaybeOwnedConnection<'c, ::Connection>: + // Connect, + C: Into::Connection>>, + E: Execute<'q, Self::Database>; + + #[doc(hidden)] + fn first(self) -> BoxFuture<'c, crate::Result>::Row>>> + where + 'q: 'c; /// Fetch the next row in the result. Returns `None` if there are no more rows. fn next(&mut self) -> BoxFuture::Row>>>; /// Map the `Row`s in this result to a different type, returning a [`Stream`] of the results. - fn map(self, f: F) -> BoxStream<'a, crate::Result> + fn map(self, f: F) -> BoxStream<'c, crate::Result> where - F: Fn(::Row) -> T; + F: MapRowFn, + T: 'c + Send + Unpin, + 'q: 'c; +} + +pub trait MapRowFn +where + Self: Send + Sync + 'static, + DB: Database, + DB: for<'c> HasRow<'c>, +{ + fn call(&self, row: ::Row) -> T; +} + +impl MapRowFn for F +where + DB: Database, + DB: for<'c> HasRow<'c>, + F: Send + Sync + 'static, + F: Fn(::Row) -> T, +{ + #[inline(always)] + fn call(&self, row: ::Row) -> T { + self(row) + } } diff --git a/sqlx-core/src/database.rs b/sqlx-core/src/database.rs index ef18f033..e8c274ee 100644 --- a/sqlx-core/src/database.rs +++ b/sqlx-core/src/database.rs @@ -13,9 +13,9 @@ use crate::types::TypeInfo; pub trait Database where Self: Sized + 'static, - Self: HasRow, + Self: for<'a> HasRow<'a, Database = Self>, Self: for<'a> HasRawValue<'a>, - Self: for<'a> HasCursor<'a, Database = Self>, + Self: for<'c, 'q> HasCursor<'c, 'q, Database = Self>, { /// The concrete `Connection` implementation for this database. type Connection: Connection; @@ -34,14 +34,14 @@ pub trait HasRawValue<'a> { type RawValue; } -pub trait HasCursor<'a> { +pub trait HasCursor<'c, 'q> { type Database: Database; - type Cursor: Cursor<'a, Database = Self::Database>; + type Cursor: Cursor<'c, 'q, Database = Self::Database>; } -pub trait HasRow { +pub trait HasRow<'a> { type Database: Database; - type Row: Row; + type Row: Row<'a, Database = Self::Database>; } diff --git a/sqlx-core/src/decode.rs b/sqlx-core/src/decode.rs index eaa166b9..81c79f00 100644 --- a/sqlx-core/src/decode.rs +++ b/sqlx-core/src/decode.rs @@ -4,7 +4,7 @@ use std::error::Error as StdError; use std::fmt::{self, Display}; use crate::database::Database; -use crate::types::HasSqlType; +use crate::types::Type; pub enum DecodeError { /// An unexpected `NULL` was encountered while decoding. @@ -40,7 +40,8 @@ where impl Decode for Option where - DB: Database + HasSqlType, + DB: Database, + T: Type, T: Decode, { fn decode(buf: &[u8]) -> Result { diff --git a/sqlx-core/src/encode.rs b/sqlx-core/src/encode.rs index 4be0c7bc..b7d38873 100644 --- a/sqlx-core/src/encode.rs +++ b/sqlx-core/src/encode.rs @@ -1,7 +1,7 @@ //! Types and traits for encoding values to the database. use crate::database::Database; -use crate::types::HasSqlType; +use crate::types::Type; use std::mem; /// The return type of [Encode::encode]. @@ -36,7 +36,8 @@ where impl Encode for &'_ T where - DB: Database + HasSqlType, + DB: Database, + T: Type, T: Encode, { fn encode(&self, buf: &mut Vec) { @@ -54,7 +55,8 @@ where impl Encode for Option where - DB: Database + HasSqlType, + DB: Database, + T: Type, T: Encode, { fn encode(&self, buf: &mut Vec) { diff --git a/sqlx-core/src/executor.rs b/sqlx-core/src/executor.rs index 43333764..8942c4f7 100644 --- a/sqlx-core/src/executor.rs +++ b/sqlx-core/src/executor.rs @@ -14,7 +14,7 @@ use futures_util::TryStreamExt; /// Implementations are provided for [`&Pool`](struct.Pool.html), /// [`&mut PoolConnection`](struct.PoolConnection.html), /// and [`&mut Connection`](trait.Connection.html). -pub trait Executor<'a> +pub trait Executor<'c> where Self: Send, { @@ -22,18 +22,18 @@ where type Database: Database; /// Executes a query that may or may not return a result set. - fn execute<'b, E>(self, query: E) -> >::Cursor + fn execute<'q, E>(self, query: E) -> >::Cursor where - E: Execute<'b, Self::Database>; + E: Execute<'q, Self::Database>; #[doc(hidden)] - fn execute_by_ref<'b, E>(&mut self, query: E) -> >::Cursor + fn execute_by_ref<'b, E>(&mut self, query: E) -> >::Cursor where E: Execute<'b, Self::Database>; } /// A type that may be executed against a database connection. -pub trait Execute<'a, DB> +pub trait Execute<'q, DB> where DB: Database, { @@ -43,15 +43,15 @@ where /// prepare the query. Returning `Some(Default::default())` is an empty arguments object that /// will be prepared (and cached) before execution. #[doc(hidden)] - fn into_parts(self) -> (&'a str, Option); + fn into_parts(self) -> (&'q str, Option); } -impl<'a, DB> Execute<'a, DB> for &'a str +impl<'q, DB> Execute<'q, DB> for &'q str where DB: Database, { #[inline] - fn into_parts(self) -> (&'a str, Option) { + fn into_parts(self) -> (&'q str, Option) { (self, None) } } diff --git a/sqlx-core/src/io/buf_stream.rs b/sqlx-core/src/io/buf_stream.rs index b96cf4df..85ca2f51 100644 --- a/sqlx-core/src/io/buf_stream.rs +++ b/sqlx-core/src/io/buf_stream.rs @@ -35,6 +35,11 @@ where } } + #[inline] + pub fn buffer(&self) -> &[u8] { + &self.rbuf[self.rbuf_rindex..] + } + #[inline] pub fn buffer_mut(&mut self) -> &mut Vec { &mut self.wbuf @@ -61,7 +66,14 @@ where self.rbuf_rindex += cnt; } - pub async fn peek(&mut self, cnt: usize) -> io::Result> { + pub async fn peek(&mut self, cnt: usize) -> io::Result<&[u8]> { + self.try_peek(cnt) + .await + .transpose() + .ok_or(io::ErrorKind::ConnectionAborted)? + } + + pub async fn try_peek(&mut self, cnt: usize) -> io::Result> { loop { // Reaching end-of-file (read 0 bytes) will continuously // return None from all future calls to read diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index 03dcc669..f6f3560e 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -1,5 +1,5 @@ -#![recursion_limit = "256"] #![forbid(unsafe_code)] +#![allow(unused)] #![cfg_attr(docsrs, feature(doc_cfg))] #[macro_use] @@ -52,7 +52,7 @@ pub use error::{Error, Result}; pub use connection::{Connect, Connection}; pub use cursor::Cursor; -pub use executor::Executor; +pub use executor::{Execute, Executor}; pub use query::{query, Query}; pub use transaction::Transaction; @@ -71,3 +71,7 @@ pub use mysql::MySql; #[cfg_attr(docsrs, doc(cfg(feature = "postgres")))] #[doc(inline)] pub use postgres::Postgres; + +// Named Lifetimes: +// 'c: connection +// 'q: query string (and arguments) diff --git a/sqlx-core/src/mysql/arguments.rs b/sqlx-core/src/mysql/arguments.rs index 9c6d8f45..f7e455a0 100644 --- a/sqlx-core/src/mysql/arguments.rs +++ b/sqlx-core/src/mysql/arguments.rs @@ -2,7 +2,7 @@ use crate::arguments::Arguments; use crate::encode::{Encode, IsNull}; use crate::mysql::types::MySqlTypeInfo; use crate::mysql::MySql; -use crate::types::HasSqlType; +use crate::types::Type; #[derive(Default)] pub struct MySqlArguments { @@ -27,10 +27,10 @@ impl Arguments for MySqlArguments { fn add(&mut self, value: T) where - Self::Database: HasSqlType, + Self::Database: Type, T: Encode, { - let type_id = >::type_info(); + let type_id = >::type_info(); let index = self.param_types.len(); self.param_types.push(type_id); diff --git a/sqlx-core/src/mysql/row.rs b/sqlx-core/src/mysql/row.rs index 45eeda4d..c21d4297 100644 --- a/sqlx-core/src/mysql/row.rs +++ b/sqlx-core/src/mysql/row.rs @@ -5,7 +5,7 @@ use crate::decode::Decode; use crate::mysql::protocol; use crate::mysql::MySql; use crate::row::{Row, RowIndex}; -use crate::types::HasSqlType; +use crate::types::Type; pub struct MySqlRow { pub(super) row: protocol::Row, @@ -21,7 +21,7 @@ impl Row for MySqlRow { fn get(&self, index: I) -> T where - Self::Database: HasSqlType, + Self::Database: Type, I: RowIndex, T: Decode, { @@ -32,7 +32,7 @@ impl Row for MySqlRow { impl RowIndex for usize { fn try_get(&self, row: &MySqlRow) -> crate::Result where - ::Database: HasSqlType, + ::Database: Type, T: Decode<::Database>, { Ok(Decode::decode_nullable(row.row.get(*self))?) @@ -42,7 +42,7 @@ impl RowIndex for usize { impl RowIndex for &'_ str { fn try_get(&self, row: &MySqlRow) -> crate::Result where - ::Database: HasSqlType, + ::Database: Type, T: Decode<::Database>, { let index = row diff --git a/sqlx-core/src/mysql/types/bool.rs b/sqlx-core/src/mysql/types/bool.rs index c1e6dd28..182a371a 100644 --- a/sqlx-core/src/mysql/types/bool.rs +++ b/sqlx-core/src/mysql/types/bool.rs @@ -3,9 +3,9 @@ use crate::encode::Encode; use crate::mysql::protocol::TypeId; use crate::mysql::types::MySqlTypeInfo; use crate::mysql::MySql; -use crate::types::HasSqlType; +use crate::types::Type; -impl HasSqlType for MySql { +impl Type for MySql { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::new(TypeId::TINY_INT) } diff --git a/sqlx-core/src/mysql/types/bytes.rs b/sqlx-core/src/mysql/types/bytes.rs index 8f3bf7d5..ec4429d9 100644 --- a/sqlx-core/src/mysql/types/bytes.rs +++ b/sqlx-core/src/mysql/types/bytes.rs @@ -6,9 +6,9 @@ use crate::mysql::io::{BufExt, BufMutExt}; use crate::mysql::protocol::TypeId; use crate::mysql::types::MySqlTypeInfo; use crate::mysql::MySql; -use crate::types::HasSqlType; +use crate::types::Type; -impl HasSqlType<[u8]> for MySql { +impl Type<[u8]> for MySql { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo { id: TypeId::TEXT, @@ -19,9 +19,9 @@ impl HasSqlType<[u8]> for MySql { } } -impl HasSqlType> for MySql { +impl Type> for MySql { fn type_info() -> MySqlTypeInfo { - >::type_info() + >::type_info() } } diff --git a/sqlx-core/src/mysql/types/chrono.rs b/sqlx-core/src/mysql/types/chrono.rs index 5c1f92d1..b9ef5446 100644 --- a/sqlx-core/src/mysql/types/chrono.rs +++ b/sqlx-core/src/mysql/types/chrono.rs @@ -9,9 +9,9 @@ use crate::io::{Buf, BufMut}; use crate::mysql::protocol::TypeId; use crate::mysql::types::MySqlTypeInfo; use crate::mysql::MySql; -use crate::types::HasSqlType; +use crate::types::Type; -impl HasSqlType> for MySql { +impl Type> for MySql { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::new(TypeId::TIMESTAMP) } @@ -31,7 +31,7 @@ impl Decode for DateTime { } } -impl HasSqlType for MySql { +impl Type for MySql { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::new(TypeId::TIME) } @@ -80,7 +80,7 @@ impl Decode for NaiveTime { } } -impl HasSqlType for MySql { +impl Type for MySql { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::new(TypeId::DATE) } @@ -104,7 +104,7 @@ impl Decode for NaiveDate { } } -impl HasSqlType for MySql { +impl Type for MySql { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::new(TypeId::DATETIME) } diff --git a/sqlx-core/src/mysql/types/float.rs b/sqlx-core/src/mysql/types/float.rs index f708424f..62250ff7 100644 --- a/sqlx-core/src/mysql/types/float.rs +++ b/sqlx-core/src/mysql/types/float.rs @@ -3,7 +3,7 @@ use crate::encode::Encode; use crate::mysql::protocol::TypeId; use crate::mysql::types::MySqlTypeInfo; use crate::mysql::MySql; -use crate::types::HasSqlType; +use crate::types::Type; /// The equivalent MySQL type for `f32` is `FLOAT`. /// @@ -18,7 +18,7 @@ use crate::types::HasSqlType; /// // (This is expected behavior for floating points and happens both in Rust and in MySQL) /// assert_ne!(10.2f32 as f64, 10.2f64); /// ``` -impl HasSqlType for MySql { +impl Type for MySql { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::new(TypeId::FLOAT) } @@ -40,7 +40,7 @@ impl Decode for f32 { /// /// Note that `DOUBLE` is a floating-point type and cannot represent some fractional values /// exactly. -impl HasSqlType for MySql { +impl Type for MySql { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::new(TypeId::DOUBLE) } diff --git a/sqlx-core/src/mysql/types/int.rs b/sqlx-core/src/mysql/types/int.rs index 05d4c849..a5d63146 100644 --- a/sqlx-core/src/mysql/types/int.rs +++ b/sqlx-core/src/mysql/types/int.rs @@ -6,9 +6,9 @@ use crate::io::{Buf, BufMut}; use crate::mysql::protocol::TypeId; use crate::mysql::types::MySqlTypeInfo; use crate::mysql::MySql; -use crate::types::HasSqlType; +use crate::types::Type; -impl HasSqlType for MySql { +impl Type for MySql { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::new(TypeId::TINY_INT) } @@ -26,7 +26,7 @@ impl Decode for i8 { } } -impl HasSqlType for MySql { +impl Type for MySql { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::new(TypeId::SMALL_INT) } @@ -44,7 +44,7 @@ impl Decode for i16 { } } -impl HasSqlType for MySql { +impl Type for MySql { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::new(TypeId::INT) } @@ -62,7 +62,7 @@ impl Decode for i32 { } } -impl HasSqlType for MySql { +impl Type for MySql { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::new(TypeId::BIG_INT) } diff --git a/sqlx-core/src/mysql/types/str.rs b/sqlx-core/src/mysql/types/str.rs index bafcd8e7..46087bf9 100644 --- a/sqlx-core/src/mysql/types/str.rs +++ b/sqlx-core/src/mysql/types/str.rs @@ -8,9 +8,9 @@ use crate::mysql::io::{BufExt, BufMutExt}; use crate::mysql::protocol::TypeId; use crate::mysql::types::MySqlTypeInfo; use crate::mysql::MySql; -use crate::types::HasSqlType; +use crate::types::Type; -impl HasSqlType for MySql { +impl Type for MySql { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo { id: TypeId::TEXT, @@ -28,9 +28,9 @@ impl Encode for str { } // TODO: Do we need the [HasSqlType] for String -impl HasSqlType for MySql { +impl Type for MySql { fn type_info() -> MySqlTypeInfo { - >::type_info() + >::type_info() } } diff --git a/sqlx-core/src/mysql/types/uint.rs b/sqlx-core/src/mysql/types/uint.rs index 81c57f1e..c6db5a72 100644 --- a/sqlx-core/src/mysql/types/uint.rs +++ b/sqlx-core/src/mysql/types/uint.rs @@ -6,9 +6,9 @@ use crate::io::{Buf, BufMut}; use crate::mysql::protocol::TypeId; use crate::mysql::types::MySqlTypeInfo; use crate::mysql::MySql; -use crate::types::HasSqlType; +use crate::types::Type; -impl HasSqlType for MySql { +impl Type for MySql { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::unsigned(TypeId::TINY_INT) } @@ -26,7 +26,7 @@ impl Decode for u8 { } } -impl HasSqlType for MySql { +impl Type for MySql { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::unsigned(TypeId::SMALL_INT) } @@ -44,7 +44,7 @@ impl Decode for u16 { } } -impl HasSqlType for MySql { +impl Type for MySql { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::unsigned(TypeId::INT) } @@ -62,7 +62,7 @@ impl Decode for u32 { } } -impl HasSqlType for MySql { +impl Type for MySql { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::unsigned(TypeId::BIG_INT) } diff --git a/sqlx-core/src/pool/conn.rs b/sqlx-core/src/pool/conn.rs index f99a02c3..fb7235b2 100644 --- a/sqlx-core/src/pool/conn.rs +++ b/sqlx-core/src/pool/conn.rs @@ -1,10 +1,11 @@ -use crate::{Connect, Connection}; +use crate::{Connect, Connection, Executor}; use futures_core::future::BoxFuture; use std::ops::{Deref, DerefMut}; use std::sync::Arc; use std::time::Instant; use super::inner::{DecrementSizeGuard, SharedPool}; +use crate::describe::Describe; /// A connection checked out from [`Pool`][crate::Pool]. /// @@ -68,6 +69,20 @@ where live.float(&self.pool).into_idle().close().await }) } + + #[inline] + fn ping(&mut self) -> BoxFuture> { + Box::pin(self.deref_mut().ping()) + } + + #[doc(hidden)] + #[inline] + fn describe<'e, 'q: 'e>( + &'e mut self, + query: &'q str, + ) -> BoxFuture<'e, crate::Result>> { + Box::pin(self.deref_mut().describe(query)) + } } /// Returns the connection to the [`Pool`][crate::Pool] it was checked-out from. @@ -168,8 +183,7 @@ impl<'s, C> Floating<'s, Idle> { where C: Connection, { - // TODO self.live.raw.ping().await - todo!() + self.live.raw.ping().await } pub fn into_live(self) -> Floating<'s, Live> { diff --git a/sqlx-core/src/pool/executor.rs b/sqlx-core/src/pool/executor.rs index b0afa616..191843a0 100644 --- a/sqlx-core/src/pool/executor.rs +++ b/sqlx-core/src/pool/executor.rs @@ -8,84 +8,89 @@ use crate::{ describe::Describe, executor::Executor, pool::Pool, - Database, + Cursor, Database, }; use super::PoolConnection; use crate::database::HasCursor; use crate::executor::Execute; -impl<'p, C> Executor<'p> for &'p Pool +impl<'p, C, DB> Executor<'p> for &'p Pool where - C: Connect, + C: Connect, + DB: Database, + DB: for<'c, 'q> HasCursor<'c, 'q>, + for<'con> &'con mut C: Executor<'con>, +{ + type Database = DB; + + fn execute<'q, E>(self, query: E) -> >::Cursor + where + E: Execute<'q, Self::Database>, + { + DB::Cursor::from_pool(self, query) + } + + #[inline] + fn execute_by_ref<'q, 'e, E>( + &'e mut self, + query: E, + ) -> >::Cursor + where + E: Execute<'q, Self::Database>, + { + self.execute(query) + } +} + +impl<'c, C, DB> Executor<'c> for &'c mut PoolConnection +where + C: Connect, + DB: Database, + DB: for<'c2, 'q> HasCursor<'c2, 'q, Database = DB>, for<'con> &'con mut C: Executor<'con>, { type Database = C::Database; - fn execute<'q, E>(self, query: E) -> >::Cursor + fn execute<'q, E>(self, query: E) -> >::Cursor where E: Execute<'q, Self::Database>, { - todo!() + DB::Cursor::from_connection(&mut **self, query) } + #[inline] fn execute_by_ref<'q, 'e, E>( &'e mut self, query: E, - ) -> >::Cursor + ) -> >::Cursor where E: Execute<'q, Self::Database>, { - todo!() + self.execute(query) } } -impl<'c, C> Executor<'c> for &'c mut PoolConnection +impl Executor<'static> for PoolConnection where - C: Connect, - for<'con> &'con mut C: Executor<'con>, + C: Connect, + DB: Database, + DB: for<'c, 'q> HasCursor<'c, 'q, Database = DB>, { - type Database = C::Database; + type Database = DB; - fn execute<'q, E>(self, query: E) -> >::Cursor + fn execute<'q, E>(self, query: E) -> >::Cursor where E: Execute<'q, Self::Database>, { - todo!() + DB::Cursor::from_connection(self, query) } - fn execute_by_ref<'q, 'e, E>( - &'e mut self, - query: E, - ) -> >::Cursor + #[inline] + fn execute_by_ref<'q, 'e, E>(&'e mut self, query: E) -> >::Cursor where E: Execute<'q, Self::Database>, { - todo!() - } -} - -impl Executor<'static> for PoolConnection -where - C: Connect, - // for<'con> &'con mut C: Executor<'con>, -{ - type Database = C::Database; - - fn execute<'q, E>(self, query: E) -> >::Cursor - where - E: Execute<'q, Self::Database>, - { - unimplemented!() - } - - fn execute_by_ref<'q, 'e, E>( - &'e mut self, - query: E, - ) -> >::Cursor - where - E: Execute<'q, Self::Database>, - { - todo!() + DB::Cursor::from_connection(&mut **self, query) } } diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index 373d3de8..9c2167d6 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -20,13 +20,15 @@ mod inner; mod options; pub use self::options::Builder; +use crate::Database; /// A pool of database connections. pub struct Pool(Arc>); -impl Pool +impl Pool where - C: Connect, + C: Connect, + DB: Database, { /// Creates a connection pool with the default configuration. /// diff --git a/sqlx-core/src/pool/options.rs b/sqlx-core/src/pool/options.rs index 318d5434..b0f58493 100644 --- a/sqlx-core/src/pool/options.rs +++ b/sqlx-core/src/pool/options.rs @@ -2,6 +2,7 @@ use std::{marker::PhantomData, time::Duration}; use super::Pool; use crate::connection::Connect; +use crate::Database; /// Builder for [Pool]. pub struct Builder { @@ -9,7 +10,11 @@ pub struct Builder { options: Options, } -impl Builder { +impl Builder +where + C: Connect, + DB: Database, +{ /// Get a new builder with default options. /// /// See the source of this method for current defaults. @@ -108,7 +113,11 @@ impl Builder { } } -impl Default for Builder { +impl Default for Builder +where + C: Connect, + DB: Database, +{ fn default() -> Self { Self::new() } diff --git a/sqlx-core/src/postgres/arguments.rs b/sqlx-core/src/postgres/arguments.rs index 8795833b..5ead9c21 100644 --- a/sqlx-core/src/postgres/arguments.rs +++ b/sqlx-core/src/postgres/arguments.rs @@ -3,7 +3,7 @@ use byteorder::{ByteOrder, NetworkEndian}; use crate::arguments::Arguments; use crate::encode::{Encode, IsNull}; use crate::io::BufMut; -use crate::types::HasSqlType; +use crate::types::Type; use crate::Postgres; #[derive(Default)] @@ -25,14 +25,13 @@ impl Arguments for PgArguments { fn add(&mut self, value: T) where - Self::Database: HasSqlType, + T: Type, T: Encode, { // TODO: When/if we receive types that do _not_ support BINARY, we need to check here // TODO: There is no need to be explicit unless we are expecting mixed BINARY / TEXT - self.types - .push(>::type_info().id.0); + self.types.push(>::type_info().id.0); let pos = self.values.len(); diff --git a/sqlx-core/src/postgres/connection.rs b/sqlx-core/src/postgres/connection.rs index 7be9102b..8e796596 100644 --- a/sqlx-core/src/postgres/connection.rs +++ b/sqlx-core/src/postgres/connection.rs @@ -1,16 +1,26 @@ use std::convert::TryInto; +use std::ops::Range; use byteorder::NetworkEndian; use futures_core::future::BoxFuture; -use std::net::Shutdown; +use futures_core::Future; +use futures_util::TryFutureExt; use crate::cache::StatementCache; use crate::connection::{Connect, Connection}; +use crate::describe::{Column, Describe}; use crate::io::{Buf, BufStream, MaybeTlsStream}; -use crate::postgres::protocol::{self, Authentication, Decode, Encode, Message, StatementId}; -use crate::postgres::{sasl, PgError}; +use crate::postgres::protocol::{ + self, Authentication, AuthenticationMd5, AuthenticationSasl, Decode, Encode, Message, + ParameterDescription, PasswordMessage, RowDescription, StartupMessage, StatementId, Terminate, +}; +use crate::postgres::sasl; +use crate::postgres::stream::PgStream; +use crate::postgres::{PgError, PgTypeInfo}; use crate::url::Url; -use crate::{Postgres, Result}; +use crate::{Error, Executor, Postgres}; + +// TODO: TLS /// An asynchronous connection to a [Postgres][super::Postgres] database. /// @@ -73,301 +83,279 @@ use crate::{Postgres, Result}; /// against the hostname in the server certificate, so they must be the same for the TLS /// upgrade to succeed. pub struct PgConnection { - pub(super) stream: BufStream, - - // Map of query to statement id - pub(super) statement_cache: StatementCache, - - // Next statement id + pub(super) stream: PgStream, pub(super) next_statement_id: u32, + pub(super) is_ready: bool, - // Process ID of the Backend - process_id: u32, - - // Backend-unique key to use to send a cancel query message to the server - secret_key: u32, - - // Is there a query in progress; are we ready to continue - pub(super) ready: bool, + // TODO: Think of a better way to do this, better name perhaps? + pub(super) data_row_values_buf: Vec>>, } -impl PgConnection { - // https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3 - async fn startup(&mut self, url: &Url) -> Result<()> { - // Defaults to postgres@.../postgres - let username = url.username().unwrap_or("postgres"); - let database = url.database().unwrap_or("postgres"); +// https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3 +async fn startup(stream: &mut PgStream, url: &Url) -> crate::Result<()> { + // Defaults to postgres@.../postgres + let username = url.username().unwrap_or("postgres"); + let database = url.database().unwrap_or("postgres"); - // See this doc for more runtime parameters - // https://www.postgresql.org/docs/12/runtime-config-client.html - let params = &[ - ("user", username), - ("database", database), - // Sets the display format for date and time values, - // as well as the rules for interpreting ambiguous date input values. - ("DateStyle", "ISO, MDY"), - // Sets the display format for interval values. - ("IntervalStyle", "iso_8601"), - // Sets the time zone for displaying and interpreting time stamps. - ("TimeZone", "UTC"), - // Adjust postgres to return percise values for floats - // NOTE: This is default in postgres 12+ - ("extra_float_digits", "3"), - // Sets the client-side encoding (character set). - ("client_encoding", "UTF-8"), - ]; + // See this doc for more runtime parameters + // https://www.postgresql.org/docs/12/runtime-config-client.html + let params = &[ + ("user", username), + ("database", database), + // Sets the display format for date and time values, + // as well as the rules for interpreting ambiguous date input values. + ("DateStyle", "ISO, MDY"), + // Sets the display format for interval values. + ("IntervalStyle", "iso_8601"), + // Sets the time zone for displaying and interpreting time stamps. + ("TimeZone", "UTC"), + // Adjust postgres to return percise values for floats + // NOTE: This is default in postgres 12+ + ("extra_float_digits", "3"), + // Sets the client-side encoding (character set). + ("client_encoding", "UTF-8"), + ]; - protocol::StartupMessage { params }.encode(self.stream.buffer_mut()); - self.stream.flush().await?; + stream.write(StartupMessage { params }); + stream.flush().await?; - while let Some(message) = self.receive().await? { - match message { - Message::Authentication(auth) => { - match *auth { - protocol::Authentication::Ok => { - // Do nothing. No password is needed to continue. - } + loop { + match stream.read().await? { + Message::Authentication => match Authentication::read(stream.buffer())? { + Authentication::Ok => { + // do nothing. no password is needed to continue. + } - protocol::Authentication::ClearTextPassword => { - protocol::PasswordMessage::ClearText( - &url.password().unwrap_or_default(), - ) - .encode(self.stream.buffer_mut()); + Authentication::CleartextPassword => { + stream.write(PasswordMessage::ClearText( + &url.password().unwrap_or_default(), + )); - self.stream.flush().await?; - } + stream.flush().await?; + } - protocol::Authentication::Md5Password { salt } => { - protocol::PasswordMessage::Md5 { - password: &url.password().unwrap_or_default(), - user: username, - salt, - } - .encode(self.stream.buffer_mut()); + Authentication::Md5Password => { + // TODO: Just reference the salt instead of returning a stack array + // TODO: Better way to make sure we skip the first 4 bytes here + let data = AuthenticationMd5::read(&stream.buffer()[4..])?; - self.stream.flush().await?; - } + stream.write(PasswordMessage::Md5 { + password: &url.password().unwrap_or_default(), + user: username, + salt: data.salt, + }); - protocol::Authentication::Sasl { mechanisms } => { - let mut has_sasl: bool = false; - let mut has_sasl_plus: bool = false; + stream.flush().await?; + } - for mechanism in &*mechanisms { - match &**mechanism { - "SCRAM-SHA-256" => { - has_sasl = true; - } + Authentication::Sasl => { + // TODO: Make this iterative for traversing the mechanisms to remove the allocation + // TODO: Better way to make sure we skip the first 4 bytes here + let data = AuthenticationSasl::read(&stream.buffer()[4..])?; - "SCRAM-SHA-256-PLUS" => { - has_sasl_plus = true; - } + let mut has_sasl: bool = false; + let mut has_sasl_plus: bool = false; - _ => { - log::info!("unsupported auth mechanism: {}", mechanism); - } - } + for mechanism in &*data.mechanisms { + match &**mechanism { + "SCRAM-SHA-256" => { + has_sasl = true; } - if has_sasl || has_sasl_plus { - // TODO: Handle -PLUS differently if we're in a TLS stream - sasl::authenticate( - self, - username, - &url.password().unwrap_or_default(), - ) - .await?; - } else { - return Err(protocol_err!( - "unsupported SASL auth mechanisms: {:?}", - mechanisms - ) - .into()); + "SCRAM-SHA-256-PLUS" => { + has_sasl_plus = true; + } + + _ => { + log::info!("unsupported auth mechanism: {}", mechanism); } } + } - auth => { - return Err(protocol_err!( - "requires unimplemented authentication method: {:?}", - auth - ) - .into()); - } + if has_sasl || has_sasl_plus { + // TODO: Handle -PLUS differently if we're in a TLS stream + sasl::authenticate(stream, username, &url.password().unwrap_or_default()) + .await?; + } else { + return Err(protocol_err!( + "unsupported SASL auth mechanisms: {:?}", + data.mechanisms + ) + .into()); } } - Message::BackendKeyData(body) => { - self.process_id = body.process_id; - self.secret_key = body.secret_key; + auth => { + return Err( + protocol_err!("requested unsupported authentication: {:?}", auth).into(), + ); } + }, - Message::ReadyForQuery(_) => { - // Connection fully established and ready to receive queries. + Message::BackendKeyData => { + // do nothing. we do not care about the server values here. + // todo: we should care and store these on the connection + } + + Message::ParameterStatus => { + // do nothing. we do not care about the server values here. + } + + Message::ReadyForQuery => { + // done. connection is now fully established and can accept + // queries for execution. + break; + } + + type_ => { + return Err(protocol_err!("unexpected message: {:?}", type_).into()); + } + } + } + + Ok(()) +} + +// https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.10 +async fn terminate(mut stream: PgStream) -> crate::Result<()> { + stream.write(Terminate); + stream.flush().await?; + stream.shutdown()?; + + Ok(()) +} + +impl PgConnection { + pub(super) async fn new(url: crate::Result) -> crate::Result { + let url = url?; + let mut stream = PgStream::new(&url).await?; + + startup(&mut stream, &url).await?; + + Ok(Self { + stream, + data_row_values_buf: Vec::new(), + next_statement_id: 1, + is_ready: true, + }) + } + + pub(super) async fn wait_until_ready(&mut self) -> crate::Result<()> { + // depending on how the previous query finished we may need to continue + // pulling messages from the stream until we receive a [ReadyForQuery] message + + // postgres sends the [ReadyForQuery] message when it's fully complete with processing + // the previous query + + if !self.is_ready { + loop { + if let Message::ReadyForQuery = self.stream.read().await? { + // we are now ready to go + self.is_ready = true; break; } - - message => { - return Err(protocol_err!("received unexpected message: {:?}", message).into()); - } } } Ok(()) } - // https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.10 - async fn terminate(mut self) -> Result<()> { - protocol::Terminate.encode(self.stream.buffer_mut()); + async fn describe<'e, 'q: 'e>( + &'e mut self, + query: &'q str, + ) -> crate::Result> { + let statement = self.write_prepare(query, &Default::default()); + + self.write_describe(protocol::Describe::Statement(statement)); + self.write_sync(); self.stream.flush().await?; - self.stream.stream.shutdown(Shutdown::Both)?; + self.wait_until_ready().await?; - Ok(()) - } - - // Wait and return the next message to be received from Postgres. - pub(super) async fn receive(&mut self) -> Result> { - loop { - // Read the message header (id + len) - let mut header = ret_if_none!(self.stream.peek(5).await?); - - let id = header.get_u8()?; - let len = (header.get_u32::()? - 4) as usize; - - // Read the message body - self.stream.consume(5); - let body = ret_if_none!(self.stream.peek(len).await?); - - let message = match id { - b'N' | b'E' => Message::Response(Box::new(protocol::Response::decode(body)?)), - b'D' => Message::DataRow(protocol::DataRow::decode(body)?), - b'S' => { - Message::ParameterStatus(Box::new(protocol::ParameterStatus::decode(body)?)) - } - b'Z' => Message::ReadyForQuery(protocol::ReadyForQuery::decode(body)?), - b'R' => Message::Authentication(Box::new(protocol::Authentication::decode(body)?)), - b'K' => Message::BackendKeyData(protocol::BackendKeyData::decode(body)?), - b'C' => Message::CommandComplete(protocol::CommandComplete::decode(body)?), - b'A' => Message::NotificationResponse(Box::new( - protocol::NotificationResponse::decode(body)?, - )), - b'1' => Message::ParseComplete, - b'2' => Message::BindComplete, - b'3' => Message::CloseComplete, - b'n' => Message::NoData, - b's' => Message::PortalSuspended, - b't' => Message::ParameterDescription(Box::new( - protocol::ParameterDescription::decode(body)?, - )), - b'T' => Message::RowDescription(Box::new(protocol::RowDescription::decode(body)?)), - - id => { - return Err(protocol_err!("received unknown message id: {:?}", id).into()); - } - }; - - self.stream.consume(len); - - match message { - Message::ParameterStatus(_body) => { - // TODO: not sure what to do with these yet + let params = loop { + match self.stream.read().await? { + Message::ParseComplete => { + // ignore complete messsage + // continue } - Message::Response(body) => { - if body.severity.is_error() { - // This is an error, stop the world and bubble as an error - return Err(PgError(body).into()); - } else { - // This is a _warning_ - // TODO: Log the warning - } + Message::ParameterDescription => { + break ParameterDescription::read(self.stream.buffer())?; } message => { - return Ok(Some(message)); + return Err(protocol_err!( + "expected ParameterDescription; received {:?}", + message + ) + .into()); } - } - } - } -} - -impl PgConnection { - pub(super) async fn establish(url: Result) -> Result { - let url = url?; - - let stream = MaybeTlsStream::connect(&url, 5432).await?; - let mut self_ = Self { - stream: BufStream::new(stream), - process_id: 0, - secret_key: 0, - // Important to start at 1 as 0 means "unnamed" in our protocol - next_statement_id: 1, - statement_cache: StatementCache::new(), - ready: true, + }; }; - let ssl_mode = url.get_param("sslmode").unwrap_or("prefer".into()); + let result = match self.stream.read().await? { + Message::NoData => None, + Message::RowDescription => Some(RowDescription::read(self.stream.buffer())?), - match &*ssl_mode { - // TODO: on "allow" retry with TLS if startup fails - "disable" | "allow" => (), - - #[cfg(feature = "tls")] - "prefer" => { - if !self_.try_ssl(&url, true, true).await? { - log::warn!("server does not support TLS, falling back to unsecured connection") - } - } - - #[cfg(not(feature = "tls"))] - "prefer" => log::info!("compiled without TLS, skipping upgrade"), - - #[cfg(feature = "tls")] - "require" | "verify-ca" | "verify-full" => { - if !self_ - .try_ssl( - &url, - ssl_mode == "require", // false for both verify-ca and verify-full - ssl_mode != "verify-full", // false for only verify-full - ) - .await? - { - return Err(tls_err!("Postgres server does not support TLS").into()); - } - } - - #[cfg(not(feature = "tls"))] - "require" | "verify-ca" | "verify-full" => { - return Err(tls_err!( - "sslmode {:?} unsupported; SQLx was compiled without `tls` feature", - ssl_mode + message => { + return Err(protocol_err!( + "expected RowDescription or NoData; received {:?}", + message ) - .into()) + .into()); } - _ => return Err(tls_err!("unknown `sslmode` value: {:?}", ssl_mode).into()), - } + }; - self_.stream.clear_bufs(); - - self_.startup(&url).await?; - - Ok(self_) + Ok(Describe { + param_types: params + .ids + .iter() + .map(|id| PgTypeInfo::new(*id)) + .collect::>() + .into_boxed_slice(), + result_columns: result + .map(|r| r.fields) + .unwrap_or_default() + .into_vec() + .into_iter() + // TODO: Should [Column] just wrap [protocol::Field] ? + .map(|field| Column { + name: field.name, + table_id: field.table_id, + type_info: PgTypeInfo::new(field.type_id), + }) + .collect::>() + .into_boxed_slice(), + }) } } impl Connect for PgConnection { - fn connect(url: T) -> BoxFuture<'static, Result> + fn connect(url: T) -> BoxFuture<'static, crate::Result> where T: TryInto, Self: Sized, { - Box::pin(PgConnection::establish(url.try_into())) + Box::pin(PgConnection::new(url.try_into())) } } impl Connection for PgConnection { type Database = Postgres; - fn close(self) -> BoxFuture<'static, Result<()>> { - Box::pin(self.terminate()) + fn close(self) -> BoxFuture<'static, crate::Result<()>> { + Box::pin(terminate(self.stream)) + } + + fn ping(&mut self) -> BoxFuture> { + Box::pin(self.execute("SELECT 1").map_ok(|_| ())) + } + + #[doc(hidden)] + fn describe<'e, 'q: 'e>( + &'e mut self, + query: &'q str, + ) -> BoxFuture<'e, crate::Result>> { + Box::pin(self.describe(query)) } } diff --git a/sqlx-core/src/postgres/cursor.rs b/sqlx-core/src/postgres/cursor.rs index ff5fa2ea..222af86f 100644 --- a/sqlx-core/src/postgres/cursor.rs +++ b/sqlx-core/src/postgres/cursor.rs @@ -1,56 +1,329 @@ use std::future::Future; +use std::mem; use std::pin::Pin; +use std::sync::Arc; use std::task::{Context, Poll}; +use async_stream::try_stream; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; -use crate::cursor::Cursor; +use crate::connection::{ConnectionSource, MaybeOwnedConnection}; +use crate::cursor::{Cursor, MapRowFn}; use crate::database::HasRow; -use crate::postgres::protocol::StatementId; -use crate::postgres::PgConnection; -use crate::Postgres; +use crate::executor::Execute; +use crate::pool::{Pool, PoolConnection}; +use crate::postgres::protocol::{CommandComplete, DataRow, Message, StatementId}; +use crate::postgres::{PgArguments, PgConnection, PgRow}; +use crate::{Database, Postgres}; -pub struct PgCursor<'a> { - statement: StatementId, - connection: &'a mut PgConnection, +enum State<'c, 'q> { + Query(&'q str, Option), + NextRow, + + // Used for `impl Future` + Resolve(BoxFuture<'c, crate::Result>>), + AffectedRows(BoxFuture<'c, crate::Result>), } -impl<'a> PgCursor<'a> { - pub(super) fn from_connection( - connection: &'a mut PgConnection, - statement: StatementId, - ) -> Self { +pub struct PgCursor<'c, 'q> { + source: ConnectionSource<'c, PgConnection>, + state: State<'c, 'q>, +} + +impl<'c, 'q> Cursor<'c, 'q> for PgCursor<'c, 'q> { + type Database = Postgres; + + #[doc(hidden)] + fn from_pool(pool: &Pool<::Connection>, query: E) -> Self + where + Self: Sized, + E: Execute<'q, Self::Database>, + { + let (query, arguments) = query.into_parts(); + Self { - connection, - statement, + // note: pool is internally reference counted + source: ConnectionSource::Pool(pool.clone()), + state: State::Query(query, arguments), + } + } + + #[doc(hidden)] + fn from_connection(conn: C, query: E) -> Self + where + Self: Sized, + C: Into::Connection>>, + E: Execute<'q, Self::Database>, + { + let (query, arguments) = query.into_parts(); + + Self { + // note: pool is internally reference counted + source: ConnectionSource::Connection(conn.into()), + state: State::Query(query, arguments), + } + } + + fn first(self) -> BoxFuture<'c, crate::Result>::Row>>> + where + 'q: 'c, + { + Box::pin(first(self)) + } + + fn next(&mut self) -> BoxFuture>::Row>>> { + Box::pin(next(self)) + } + + fn map(mut self, f: F) -> BoxStream<'c, crate::Result> + where + F: MapRowFn, + T: 'c + Send + Unpin, + 'q: 'c, + { + Box::pin(try_stream! { + while let Some(row) = self.next().await? { + yield f.call(row); + } + }) + } +} + +impl<'s, 'q> Future for PgCursor<'s, 'q> { + type Output = crate::Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + loop { + match &mut self.state { + State::Query(q, arguments) => { + // todo: existential types can remove both the boxed futures + // and this allocation + let query = q.to_owned(); + let arguments = mem::take(arguments); + + self.state = State::Resolve(Box::pin(resolve( + mem::take(&mut self.source), + query, + arguments, + ))); + } + + State::Resolve(fut) => { + match fut.as_mut().poll(cx) { + Poll::Pending => { + return Poll::Pending; + } + + Poll::Ready(conn) => { + let conn = conn?; + + self.state = State::AffectedRows(Box::pin(affected_rows(conn))); + + // continue + } + } + } + + State::NextRow => { + panic!("PgCursor must not be polled after being used"); + } + + State::AffectedRows(fut) => { + return fut.as_mut().poll(cx); + } + } } } } -impl<'a> Cursor<'a> for PgCursor<'a> { - type Database = Postgres; +// write out query to the connection stream +async fn write( + conn: &mut PgConnection, + query: &str, + arguments: Option, +) -> crate::Result<()> { + // TODO: Handle [arguments] being None. This should be a SIMPLE query. + let arguments = arguments.unwrap(); - fn first(self) -> BoxFuture<'a, crate::Result::Row>>> { - todo!() - } + // Check the statement cache for a statement ID that matches the given query + // If it doesn't exist, we generate a new statement ID and write out [Parse] to the + // connection command buffer + let statement = conn.write_prepare(query, &arguments); - fn next(&mut self) -> BoxFuture::Row>>> { - todo!() - } + // Next, [Bind] attaches the arguments to the statement and creates a named portal + conn.write_bind("", statement, &arguments); - fn map(self, f: F) -> BoxStream<'a, crate::Result> - where - F: Fn(::Row) -> T, - { - todo!() - } + // Next, [Describe] will return the expected result columns and types + // Conditionally run [Describe] only if the results have not been cached + // if !self.statement_cache.has_columns(statement) { + // self.write_describe(protocol::Describe::Portal("")); + // } + + // Next, [Execute] then executes the named portal + conn.write_execute("", 0); + + // Finally, [Sync] asks postgres to process the messages that we sent and respond with + // a [ReadyForQuery] message when it's completely done. Theoretically, we could send + // dozens of queries before a [Sync] and postgres can handle that. Execution on the server + // is still serial but it would reduce round-trips. Some kind of builder pattern that is + // termed batching might suit this. + conn.write_sync(); + + conn.wait_until_ready().await?; + + conn.stream.flush().await?; + conn.is_ready = false; + + Ok(()) } -impl<'a> Future for PgCursor<'a> { - type Output = crate::Result; +async fn resolve( + mut source: ConnectionSource<'_, PgConnection>, + query: String, + arguments: Option, +) -> crate::Result> { + let mut conn = source.resolve_by_ref().await?; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - todo!() - } + write(&mut *conn, &query, arguments).await?; + + Ok(source.into_connection()) +} + +async fn affected_rows(mut conn: MaybeOwnedConnection<'_, PgConnection>) -> crate::Result { + conn.wait_until_ready().await?; + + conn.stream.flush().await?; + conn.is_ready = false; + + let mut rows = 0; + + loop { + match conn.stream.read().await? { + Message::ParseComplete | Message::BindComplete => { + // ignore x_complete messages + } + + Message::DataRow => { + // ignore rows + // TODO: should we log or something? + } + + Message::CommandComplete => { + rows += CommandComplete::read(conn.stream.buffer())?.affected_rows; + } + + Message::ReadyForQuery => { + // done + break; + } + + message => { + return Err(protocol_err!("unexpected message: {:?}", message).into()); + } + } + } + + Ok(rows) +} + +async fn next<'a, 'c: 'a, 'q: 'a>( + cursor: &'a mut PgCursor<'c, 'q>, +) -> crate::Result>> { + let mut conn = cursor.source.resolve_by_ref().await?; + + match cursor.state { + State::Query(q, ref mut arguments) => { + // write out the query to the connection + write(&mut *conn, q, arguments.take()).await?; + + // next time we come through here, skip this block + cursor.state = State::NextRow; + } + + State::Resolve(_) | State::AffectedRows(_) => { + panic!("`PgCursor` must not be used after being polled"); + } + + State::NextRow => { + // grab the next row + } + } + + loop { + match conn.stream.read().await? { + Message::ParseComplete | Message::BindComplete => { + // ignore x_complete messages + } + + Message::CommandComplete => { + // no more rows + break; + } + + Message::DataRow => { + let data = DataRow::read(&mut *conn)?; + + return Ok(Some(PgRow { + connection: conn, + columns: Arc::default(), + data, + })); + } + + message => { + return Err(protocol_err!("unexpected message: {:?}", message).into()); + } + } + } + + Ok(None) +} + +async fn first<'c, 'q>(mut cursor: PgCursor<'c, 'q>) -> crate::Result>> { + let mut conn = cursor.source.resolve().await?; + + match cursor.state { + State::Query(q, ref mut arguments) => { + // write out the query to the connection + write(&mut conn, q, arguments.take()).await?; + } + + State::NextRow => { + // just grab the next row as the first + } + + State::Resolve(_) | State::AffectedRows(_) => { + panic!("`PgCursor` must not be used after being polled"); + } + } + + loop { + match conn.stream.read().await? { + Message::ParseComplete | Message::BindComplete => { + // ignore x_complete messages + } + + Message::CommandComplete => { + // no more rows + break; + } + + Message::DataRow => { + let data = DataRow::read(&mut conn)?; + + return Ok(Some(PgRow { + connection: conn, + columns: Arc::default(), + data, + })); + } + + message => { + return Err(protocol_err!("unexpected message: {:?}", message).into()); + } + } + } + + Ok(None) } diff --git a/sqlx-core/src/postgres/database.rs b/sqlx-core/src/postgres/database.rs index 0bfc0e6e..79594045 100644 --- a/sqlx-core/src/postgres/database.rs +++ b/sqlx-core/src/postgres/database.rs @@ -13,18 +13,18 @@ impl Database for Postgres { type TableId = u32; } -impl HasRow for Postgres { +impl<'a> HasRow<'a> for Postgres { // TODO: Can we drop the `type Database = _` type Database = Postgres; - type Row = super::PgRow; + type Row = super::PgRow<'a>; } -impl<'a> HasCursor<'a> for Postgres { +impl<'s, 'q> HasCursor<'s, 'q> for Postgres { // TODO: Can we drop the `type Database = _` type Database = Postgres; - type Cursor = super::PgCursor<'a>; + type Cursor = super::PgCursor<'s, 'q>; } impl<'a> HasRawValue<'a> for Postgres { diff --git a/sqlx-core/src/postgres/error.rs b/sqlx-core/src/postgres/error.rs index 84e30b2b..537d0d2d 100644 --- a/sqlx-core/src/postgres/error.rs +++ b/sqlx-core/src/postgres/error.rs @@ -1,7 +1,7 @@ use crate::error::DatabaseError; use crate::postgres::protocol::Response; -pub struct PgError(pub(super) Box); +pub struct PgError(pub(super) Response); impl DatabaseError for PgError { fn message(&self) -> &str { diff --git a/sqlx-core/src/postgres/executor.rs b/sqlx-core/src/postgres/executor.rs index 35c79269..67202b75 100644 --- a/sqlx-core/src/postgres/executor.rs +++ b/sqlx-core/src/postgres/executor.rs @@ -2,37 +2,36 @@ use std::collections::HashMap; use std::io; use std::sync::Arc; +use crate::cursor::Cursor; use crate::executor::{Execute, Executor}; -use crate::postgres::protocol::{self, Encode, Message, StatementId, TypeFormat}; -use crate::postgres::{PgArguments, PgCursor, PgRow, PgTypeInfo, Postgres}; +use crate::postgres::protocol::{self, Encode, StatementId, TypeFormat}; +use crate::postgres::{PgArguments, PgConnection, PgCursor, PgRow, PgTypeInfo, Postgres}; -impl super::PgConnection { - fn write_prepare(&mut self, query: &str, args: &PgArguments) -> StatementId { - if let Some(&id) = self.statement_cache.get(query) { - id - } else { - let id = StatementId(self.next_statement_id); - self.next_statement_id += 1; +impl PgConnection { + pub(crate) fn write_prepare(&mut self, query: &str, args: &PgArguments) -> StatementId { + // TODO: check query cache - protocol::Parse { - statement: id, - query, - param_types: &*args.types, - } - .encode(self.stream.buffer_mut()); + let id = StatementId(self.next_statement_id); - self.statement_cache.put(query.to_owned(), id); + self.next_statement_id += 1; - id - } + self.stream.write(protocol::Parse { + statement: id, + query, + param_types: &*args.types, + }); + + // TODO: write to query cache + + id } - fn write_describe(&mut self, d: protocol::Describe) { - d.encode(self.stream.buffer_mut()) + pub(crate) fn write_describe(&mut self, d: protocol::Describe) { + self.stream.write(d); } - fn write_bind(&mut self, portal: &str, statement: StatementId, args: &PgArguments) { - protocol::Bind { + pub(crate) fn write_bind(&mut self, portal: &str, statement: StatementId, args: &PgArguments) { + self.stream.write(protocol::Bind { portal, statement, formats: &[TypeFormat::Binary], @@ -40,59 +39,30 @@ impl super::PgConnection { values_len: args.types.len() as i16, values: &*args.values, result_formats: &[TypeFormat::Binary], - } - .encode(self.stream.buffer_mut()); + }); } - fn write_execute(&mut self, portal: &str, limit: i32) { - protocol::Execute { portal, limit }.encode(self.stream.buffer_mut()); + pub(crate) fn write_execute(&mut self, portal: &str, limit: i32) { + self.stream.write(protocol::Execute { portal, limit }); } - fn write_sync(&mut self) { - protocol::Sync.encode(self.stream.buffer_mut()); + pub(crate) fn write_sync(&mut self) { + self.stream.write(protocol::Sync); } } impl<'e> Executor<'e> for &'e mut super::PgConnection { type Database = Postgres; - fn execute<'q, E>(self, query: E) -> PgCursor<'e> + fn execute<'q, E>(self, query: E) -> PgCursor<'e, 'q> where E: Execute<'q, Self::Database>, { - let (query, arguments) = query.into_parts(); - - // TODO: Handle [arguments] being None. This should be a SIMPLE query. - let arguments = arguments.unwrap(); - - // Check the statement cache for a statement ID that matches the given query - // If it doesn't exist, we generate a new statement ID and write out [Parse] to the - // connection command buffer - let statement = self.write_prepare(query, &arguments); - - // Next, [Bind] attaches the arguments to the statement and creates a named portal - self.write_bind("", statement, &arguments); - - // Next, [Describe] will return the expected result columns and types - // Conditionally run [Describe] only if the results have not been cached - if !self.statement_cache.has_columns(statement) { - self.write_describe(protocol::Describe::Portal("")); - } - - // Next, [Execute] then executes the named portal - self.write_execute("", 0); - - // Finally, [Sync] asks postgres to process the messages that we sent and respond with - // a [ReadyForQuery] message when it's completely done. Theoretically, we could send - // dozens of queries before a [Sync] and postgres can handle that. Execution on the server - // is still serial but it would reduce round-trips. Some kind of builder pattern that is - // termed batching might suit this. - self.write_sync(); - - PgCursor::from_connection(self, statement) + PgCursor::from_connection(self, query) } - fn execute_by_ref<'q, E>(&mut self, query: E) -> PgCursor<'_> + #[inline] + fn execute_by_ref<'q, E>(&mut self, query: E) -> PgCursor<'_, 'q> where E: Execute<'q, Self::Database>, { diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs index 3e9ac4ce..b0954bb6 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -17,7 +17,8 @@ mod executor; mod protocol; mod row; mod sasl; -mod tls; +mod stream; +// mod tls; mod types; /// An alias for [`Pool`][crate::Pool], specialized for **Postgres**. diff --git a/sqlx-core/src/postgres/protocol/authentication.rs b/sqlx-core/src/postgres/protocol/authentication.rs index dd167d92..869e5098 100644 --- a/sqlx-core/src/postgres/protocol/authentication.rs +++ b/sqlx-core/src/postgres/protocol/authentication.rs @@ -5,152 +5,176 @@ use std::str; #[derive(Debug)] pub enum Authentication { - /// Authentication was successful. + /// The authentication exchange is successfully completed. Ok, - /// Kerberos V5 authentication is required. + /// The frontend must now take part in a Kerberos V5 authentication dialog (not described + /// here, part of the Kerberos specification) with the server. If this is successful, + /// the server responds with an `AuthenticationOk`, otherwise it responds + /// with an `ErrorResponse`. This is no longer supported. KerberosV5, - /// A clear-text password is required. - ClearTextPassword, + /// The frontend must now send a `PasswordMessage` containing the password in clear-text form. + /// If this is the correct password, the server responds with an `AuthenticationOk`, otherwise it + /// responds with an `ErrorResponse`. + CleartextPassword, - /// An MD5-encrypted password is required. - Md5Password { salt: [u8; 4] }, + /// The frontend must now send a `PasswordMessage` containing the password (with user name) + /// encrypted via MD5, then encrypted again using the 4-byte random salt specified in the + /// `AuthenticationMD5Password` message. If this is the correct password, the server responds + /// with an `AuthenticationOk`, otherwise it responds with an `ErrorResponse`. + Md5Password, - /// An SCM credentials message is required. + /// This response is only possible for local Unix-domain connections on platforms that support + /// SCM credential messages. The frontend must issue an SCM credential message and then + /// send a single data byte. ScmCredential, - /// GSSAPI authentication is required. + /// The frontend must now initiate a GSSAPI negotiation. The frontend will send a + /// `GSSResponse` message with the first part of the GSSAPI data stream in response to this. Gss, - /// SSPI authentication is required. + /// The frontend must now initiate a SSPI negotiation. + /// The frontend will send a GSSResponse with the first part of the SSPI data stream in + /// response to this. Sspi, - /// This message contains GSSAPI or SSPI data. - GssContinue { data: Box<[u8]> }, + /// This message contains the response data from the previous step of GSSAPI + /// or SSPI negotiation. + GssContinue, - /// SASL authentication is required. - /// - /// The message body is a list of SASL authentication mechanisms, - /// in the server's order of preference. - Sasl { mechanisms: Box<[Box]> }, + /// The frontend must now initiate a SASL negotiation, using one of the SASL mechanisms + /// listed in the message. + Sasl, - /// This message contains a SASL challenge. - SaslContinue(SaslContinue), + /// This message contains challenge data from the previous step of SASL negotiation. + SaslContinue, - /// SASL authentication has completed. - SaslFinal { data: Box<[u8]> }, + /// SASL authentication has completed with additional mechanism-specific data for the client. + SaslFinal, +} + +impl Authentication { + pub fn read(mut buf: &[u8]) -> crate::Result { + Ok(match buf.get_u32::()? { + 0 => Authentication::Ok, + 2 => Authentication::KerberosV5, + 3 => Authentication::CleartextPassword, + 5 => Authentication::Md5Password, + 6 => Authentication::ScmCredential, + 7 => Authentication::Gss, + 8 => Authentication::GssContinue, + 9 => Authentication::Sspi, + 10 => Authentication::Sasl, + 11 => Authentication::SaslContinue, + 12 => Authentication::SaslFinal, + + type_ => { + return Err(protocol_err!("unknown authentication message type: {}", type_).into()); + } + }) + } } #[derive(Debug)] -pub struct SaslContinue { +pub struct AuthenticationMd5 { + pub salt: [u8; 4], +} + +impl AuthenticationMd5 { + pub fn read(mut buf: &[u8]) -> crate::Result { + let mut salt = [0_u8; 4]; + salt.copy_from_slice(buf); + + Ok(Self { salt }) + } +} + +#[derive(Debug)] +pub struct AuthenticationSasl { + pub mechanisms: Box<[Box]>, +} + +impl AuthenticationSasl { + pub fn read(mut buf: &[u8]) -> crate::Result { + let mut mechanisms = Vec::new(); + + while buf[0] != 0 { + mechanisms.push(buf.get_str_nul()?.into()); + } + + Ok(Self { + mechanisms: mechanisms.into_boxed_slice(), + }) + } +} + +#[derive(Debug)] +pub struct AuthenticationSaslContinue { pub salt: Vec, pub iter_count: u32, pub nonce: Vec, pub data: String, } -impl Decode for Authentication { - fn decode(mut buf: &[u8]) -> crate::Result { - Ok(match buf.get_u32::()? { - 0 => Authentication::Ok, +impl AuthenticationSaslContinue { + pub fn read(mut buf: &[u8]) -> crate::Result { + let mut salt: Vec = Vec::new(); + let mut nonce: Vec = Vec::new(); + let mut iter_count: u32 = 0; - 2 => Authentication::KerberosV5, + let key_value: Vec<(char, &[u8])> = buf + .split(|byte| *byte == b',') + .map(|s| { + let (key, value) = s.split_at(1); + let value = value.split_at(1).1; - 3 => Authentication::ClearTextPassword, + (key[0] as char, value) + }) + .collect(); - 5 => { - let mut salt = [0_u8; 4]; - salt.copy_from_slice(&buf); - - Authentication::Md5Password { salt } - } - - 6 => Authentication::ScmCredential, - - 7 => Authentication::Gss, - - 8 => { - let mut data = Vec::with_capacity(buf.len()); - data.extend_from_slice(buf); - - Authentication::GssContinue { - data: data.into_boxed_slice(), - } - } - - 9 => Authentication::Sspi, - - 10 => { - let mut mechanisms = Vec::new(); - - while buf[0] != 0 { - mechanisms.push(buf.get_str_nul()?.into()); + for (key, value) in key_value.iter() { + match key { + 's' => salt = value.to_vec(), + 'r' => nonce = value.to_vec(), + 'i' => { + let s = str::from_utf8(&value).map_err(|_| { + protocol_err!( + "iteration count in sasl response was not a valid utf8 string" + ) + })?; + iter_count = u32::from_str_radix(&s, 10).unwrap_or(0); } - Authentication::Sasl { - mechanisms: mechanisms.into_boxed_slice(), - } + _ => {} } + } - 11 => { - let mut salt: Vec = Vec::new(); - let mut nonce: Vec = Vec::new(); - let mut iter_count: u32 = 0; + Ok(Self { + salt: base64::decode(&salt).map_err(|_| { + protocol_err!("salt value response from postgres was not base64 encoded") + })?, + nonce, + iter_count, + data: str::from_utf8(buf) + .map_err(|_| protocol_err!("SaslContinue response was not a valid utf8 string"))? + .to_string(), + }) + } +} - let key_value: Vec<(char, &[u8])> = buf - .split(|byte| *byte == b',') - .map(|s| { - let (key, value) = s.split_at(1); - let value = value.split_at(1).1; +#[derive(Debug)] +pub struct AuthenticationSaslFinal { + pub data: Box<[u8]>, +} - (key[0] as char, value) - }) - .collect(); +impl AuthenticationSaslFinal { + pub fn read(mut buf: &[u8]) -> crate::Result { + let mut data = Vec::with_capacity(buf.len()); + data.extend_from_slice(buf); - for (key, value) in key_value.iter() { - match key { - 's' => salt = value.to_vec(), - 'r' => nonce = value.to_vec(), - 'i' => { - let s = str::from_utf8(&value).map_err(|_| { - protocol_err!( - "iteration count in sasl response was not a valid utf8 string" - ) - })?; - iter_count = u32::from_str_radix(&s, 10).unwrap_or(0); - } - - _ => {} - } - } - - Authentication::SaslContinue(SaslContinue { - salt: base64::decode(&salt).map_err(|_| { - protocol_err!("salt value response from postgres was not base64 encoded") - })?, - nonce, - iter_count, - data: str::from_utf8(buf) - .map_err(|_| { - protocol_err!("SaslContinue response was not a valid utf8 string") - })? - .to_string(), - }) - } - - 12 => { - let mut data = Vec::with_capacity(buf.len()); - data.extend_from_slice(buf); - - Authentication::SaslFinal { - data: data.into_boxed_slice(), - } - } - - id => { - return Err(protocol_err!("unknown authentication response: {}", id).into()); - } + Ok(Self { + data: data.into_boxed_slice(), }) } } @@ -158,27 +182,25 @@ impl Decode for Authentication { #[cfg(test)] mod tests { use super::{Authentication, Decode}; + use crate::postgres::protocol::authentication::AuthenticationMd5; use matches::assert_matches; const AUTH_OK: &[u8] = b"\0\0\0\0"; const AUTH_MD5: &[u8] = b"\0\0\0\x05\x93\x189\x98"; #[test] - fn it_decodes_auth_ok() { - let m = Authentication::decode(AUTH_OK).unwrap(); + fn it_reads_auth_ok() { + let m = Authentication::read(AUTH_OK).unwrap(); assert_matches!(m, Authentication::Ok); } #[test] - fn it_decodes_auth_md5_password() { - let m = Authentication::decode(AUTH_MD5).unwrap(); + fn it_reads_auth_md5_password() { + let m = Authentication::read(AUTH_MD5).unwrap(); + let data = AuthenticationMd5::read(&AUTH_MD5[4..]).unwrap(); - assert_matches!( - m, - Authentication::Md5Password { - salt: [147, 24, 57, 152] - } - ); + assert_matches!(m, Authentication::Md5Password); + assert_eq!(data.salt, [147, 24, 57, 152]); } } diff --git a/sqlx-core/src/postgres/protocol/command_complete.rs b/sqlx-core/src/postgres/protocol/command_complete.rs index 734c6343..56695787 100644 --- a/sqlx-core/src/postgres/protocol/command_complete.rs +++ b/sqlx-core/src/postgres/protocol/command_complete.rs @@ -6,8 +6,8 @@ pub struct CommandComplete { pub affected_rows: u64, } -impl Decode for CommandComplete { - fn decode(mut buf: &[u8]) -> crate::Result { +impl CommandComplete { + pub(crate) fn read(mut buf: &[u8]) -> crate::Result { // Attempt to parse the last word in the command tag as an integer // If it can't be parsed, the tag is probably "CREATE TABLE" or something // and we should return 0 rows @@ -35,29 +35,29 @@ mod tests { const COMMAND_COMPLETE_BEGIN: &[u8] = b"BEGIN\0"; #[test] - fn it_decodes_command_complete_for_insert() { - let message = CommandComplete::decode(COMMAND_COMPLETE_INSERT).unwrap(); + fn it_reads_command_complete_for_insert() { + let message = CommandComplete::read(COMMAND_COMPLETE_INSERT).unwrap(); assert_eq!(message.affected_rows, 1); } #[test] - fn it_decodes_command_complete_for_update() { - let message = CommandComplete::decode(COMMAND_COMPLETE_UPDATE).unwrap(); + fn it_reads_command_complete_for_update() { + let message = CommandComplete::read(COMMAND_COMPLETE_UPDATE).unwrap(); assert_eq!(message.affected_rows, 512); } #[test] - fn it_decodes_command_complete_for_begin() { - let message = CommandComplete::decode(COMMAND_COMPLETE_BEGIN).unwrap(); + fn it_reads_command_complete_for_begin() { + let message = CommandComplete::read(COMMAND_COMPLETE_BEGIN).unwrap(); assert_eq!(message.affected_rows, 0); } #[test] - fn it_decodes_command_complete_for_create_table() { - let message = CommandComplete::decode(COMMAND_COMPLETE_CREATE_TABLE).unwrap(); + fn it_reads_command_complete_for_create_table() { + let message = CommandComplete::read(COMMAND_COMPLETE_CREATE_TABLE).unwrap(); assert_eq!(message.affected_rows, 0); } diff --git a/sqlx-core/src/postgres/protocol/data_row.rs b/sqlx-core/src/postgres/protocol/data_row.rs index 1fb6ad72..a5148ebf 100644 --- a/sqlx-core/src/postgres/protocol/data_row.rs +++ b/sqlx-core/src/postgres/protocol/data_row.rs @@ -1,34 +1,49 @@ use crate::io::{Buf, ByteStr}; use crate::postgres::protocol::Decode; +use crate::postgres::PgConnection; use byteorder::NetworkEndian; use std::fmt::{self, Debug}; use std::ops::Range; pub struct DataRow { - buffer: Box<[u8]>, - values: Box<[Option>]>, + len: u16, } impl DataRow { pub fn len(&self) -> usize { - self.values.len() + self.len as usize } - pub fn get(&self, index: usize) -> Option<&[u8]> { - let range = self.values[index].as_ref()?; + pub fn get<'a>( + &self, + buffer: &'a [u8], + values: &[Option>], + index: usize, + ) -> Option<&'a [u8]> { + let range = values[index].as_ref()?; - Some(&self.buffer[(range.start as usize)..(range.end as usize)]) + Some(&buffer[(range.start as usize)..(range.end as usize)]) } } -impl Decode for DataRow { - fn decode(mut buf: &[u8]) -> crate::Result { - let len = buf.get_u16::()? as usize; - let buffer: Box<[u8]> = buf.into(); - let mut values = Vec::with_capacity(len); - let mut index = 4; +impl DataRow { + pub(crate) fn read<'a>( + connection: &mut PgConnection, + // buffer: &'a [u8], + // values: &'a mut Vec>>, + ) -> crate::Result { + let buffer = connection.stream.buffer(); + let values = &mut connection.data_row_values_buf; - while values.len() < len { + values.clear(); + + let mut buf = buffer; + + let len = buf.get_u16::()?; + + let mut index = 6; + + while values.len() < (len as usize) { // The length of the column value, in bytes (this count does not include itself). // Can be zero. As a special case, -1 indicates a NULL column value. // No value bytes follow in the NULL case. @@ -46,26 +61,7 @@ impl Decode for DataRow { } } - Ok(Self { - values: values.into_boxed_slice(), - buffer, - }) - } -} - -impl Debug for DataRow { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "DataRow(")?; - - let len = self.values.len(); - - f.debug_list() - .entries((0..len).map(|i| self.get(i).map(ByteStr))) - .finish()?; - - write!(f, ")")?; - - Ok(()) + Ok(Self { len }) } } @@ -76,18 +72,14 @@ mod tests { const DATA_ROW: &[u8] = b"\0\x03\0\0\0\x011\0\0\0\x012\0\0\0\x013"; #[test] - fn it_decodes_data_row() { - let m = DataRow::decode(DATA_ROW).unwrap(); + fn it_reads_data_row() { + let mut values = Vec::new(); + let m = DataRow::read(DATA_ROW, &mut values).unwrap(); - assert_eq!(m.values.len(), 3); + assert_eq!(m.len, 3); - assert_eq!(m.get(0), Some(&b"1"[..])); - assert_eq!(m.get(1), Some(&b"2"[..])); - assert_eq!(m.get(2), Some(&b"3"[..])); - - assert_eq!( - format!("{:?}", m), - "DataRow([Some(b\"1\"), Some(b\"2\"), Some(b\"3\")])" - ); + assert_eq!(m.get(DATA_ROW, &values, 0), Some(&b"1"[..])); + assert_eq!(m.get(DATA_ROW, &values, 1), Some(&b"2"[..])); + assert_eq!(m.get(DATA_ROW, &values, 2), Some(&b"3"[..])); } } diff --git a/sqlx-core/src/postgres/protocol/message.rs b/sqlx-core/src/postgres/protocol/message.rs index a90b7021..dcc4a01b 100644 --- a/sqlx-core/src/postgres/protocol/message.rs +++ b/sqlx-core/src/postgres/protocol/message.rs @@ -1,24 +1,57 @@ +use std::convert::TryFrom; + use crate::postgres::protocol::{ Authentication, BackendKeyData, CommandComplete, DataRow, NotificationResponse, ParameterDescription, ParameterStatus, ReadyForQuery, Response, RowDescription, }; -#[derive(Debug)] +#[derive(Debug, Copy, Clone)] #[repr(u8)] pub enum Message { - Authentication(Box), - ParameterStatus(Box), - BackendKeyData(BackendKeyData), - ReadyForQuery(ReadyForQuery), - CommandComplete(CommandComplete), - DataRow(DataRow), - Response(Box), - NotificationResponse(Box), - ParseComplete, + Authentication, + BackendKeyData, BindComplete, CloseComplete, + CommandComplete, + DataRow, NoData, + NotificationResponse, + ParameterDescription, + ParameterStatus, + ParseComplete, PortalSuspended, - ParameterDescription(Box), - RowDescription(Box), + ReadyForQuery, + NoticeResponse, + ErrorResponse, + RowDescription, +} + +impl TryFrom for Message { + type Error = crate::Error; + + fn try_from(type_: u8) -> crate::Result { + // https://www.postgresql.org/docs/12/protocol-message-formats.html + Ok(match type_ { + b'E' => Message::ErrorResponse, + b'N' => Message::NoticeResponse, + b'D' => Message::DataRow, + b'S' => Message::ParameterStatus, + b'Z' => Message::ReadyForQuery, + b'R' => Message::Authentication, + b'K' => Message::BackendKeyData, + b'C' => Message::CommandComplete, + b'A' => Message::NotificationResponse, + b'1' => Message::ParseComplete, + b'2' => Message::BindComplete, + b'3' => Message::CloseComplete, + b'n' => Message::NoData, + b's' => Message::PortalSuspended, + b't' => Message::ParameterDescription, + b'T' => Message::RowDescription, + + id => { + return Err(protocol_err!("unknown message: {:?}", id).into()); + } + }) + } } diff --git a/sqlx-core/src/postgres/protocol/mod.rs b/sqlx-core/src/postgres/protocol/mod.rs index 324d05e4..7669a170 100644 --- a/sqlx-core/src/postgres/protocol/mod.rs +++ b/sqlx-core/src/postgres/protocol/mod.rs @@ -58,7 +58,10 @@ mod row_description; mod message; -pub use authentication::Authentication; +pub use authentication::{ + Authentication, AuthenticationMd5, AuthenticationSasl, AuthenticationSaslContinue, + AuthenticationSaslFinal, +}; pub use backend_key_data::BackendKeyData; pub use command_complete::CommandComplete; pub use data_row::DataRow; diff --git a/sqlx-core/src/postgres/protocol/parameter_description.rs b/sqlx-core/src/postgres/protocol/parameter_description.rs index ed30c9df..1dd9c119 100644 --- a/sqlx-core/src/postgres/protocol/parameter_description.rs +++ b/sqlx-core/src/postgres/protocol/parameter_description.rs @@ -7,8 +7,8 @@ pub struct ParameterDescription { pub ids: Box<[TypeId]>, } -impl Decode for ParameterDescription { - fn decode(mut buf: &[u8]) -> crate::Result { +impl ParameterDescription { + pub(crate) fn read(mut buf: &[u8]) -> crate::Result { let cnt = buf.get_u16::()? as usize; let mut ids = Vec::with_capacity(cnt); @@ -27,9 +27,9 @@ mod test { use super::{Decode, ParameterDescription}; #[test] - fn it_decodes_parameter_description() { + fn it_reads_parameter_description() { let buf = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00"; - let desc = ParameterDescription::decode(buf).unwrap(); + let desc = ParameterDescription::read(buf).unwrap(); assert_eq!(desc.ids.len(), 2); assert_eq!(desc.ids[0].0, 0x0000_0000); @@ -37,9 +37,9 @@ mod test { } #[test] - fn it_decodes_empty_parameter_description() { + fn it_reads_empty_parameter_description() { let buf = b"\x00\x00"; - let desc = ParameterDescription::decode(buf).unwrap(); + let desc = ParameterDescription::read(buf).unwrap(); assert_eq!(desc.ids.len(), 0); } diff --git a/sqlx-core/src/postgres/protocol/response.rs b/sqlx-core/src/postgres/protocol/response.rs index 92be34b2..d01eb6aa 100644 --- a/sqlx-core/src/postgres/protocol/response.rs +++ b/sqlx-core/src/postgres/protocol/response.rs @@ -65,8 +65,8 @@ pub struct Response { pub routine: Option>, } -impl Decode for Response { - fn decode(mut buf: &[u8]) -> crate::Result { +impl Response { + pub fn read(mut buf: &[u8]) -> crate::Result { let mut code = None::>; let mut message = None::>; let mut severity = None::>; diff --git a/sqlx-core/src/postgres/protocol/row_description.rs b/sqlx-core/src/postgres/protocol/row_description.rs index 02d8b3d7..eb9b8aef 100644 --- a/sqlx-core/src/postgres/protocol/row_description.rs +++ b/sqlx-core/src/postgres/protocol/row_description.rs @@ -18,8 +18,8 @@ pub struct Field { pub type_format: TypeFormat, } -impl Decode for RowDescription { - fn decode(mut buf: &[u8]) -> crate::Result { +impl RowDescription { + pub(crate) fn read(mut buf: &[u8]) -> crate::Result { let cnt = buf.get_u16::()? as usize; let mut fields = Vec::with_capacity(cnt); @@ -57,7 +57,7 @@ mod test { use super::{Decode, RowDescription}; #[test] - fn it_decodes_row_description() { + fn it_reads_row_description() { #[rustfmt::skip] let buf = bytes! { // Number of Parameters @@ -82,7 +82,7 @@ mod test { 0_u8, 0_u8 // format_code }; - let desc = RowDescription::decode(&buf).unwrap(); + let desc = RowDescription::read(&buf).unwrap(); assert_eq!(desc.fields.len(), 2); assert_eq!(desc.fields[0].type_id.0, 0x0000_0000); @@ -90,9 +90,9 @@ mod test { } #[test] - fn it_decodes_empty_row_description() { + fn it_reads_empty_row_description() { let buf = b"\x00\x00"; - let desc = RowDescription::decode(buf).unwrap(); + let desc = RowDescription::read(buf).unwrap(); assert_eq!(desc.fields.len(), 0); } diff --git a/sqlx-core/src/postgres/row.rs b/sqlx-core/src/postgres/row.rs index 53b92ae1..c6d06a20 100644 --- a/sqlx-core/src/postgres/row.rs +++ b/sqlx-core/src/postgres/row.rs @@ -1,58 +1,58 @@ use std::collections::HashMap; use std::sync::Arc; +use crate::connection::MaybeOwnedConnection; use crate::decode::Decode; +use crate::pool::PoolConnection; use crate::postgres::protocol::DataRow; -use crate::postgres::Postgres; +use crate::postgres::{PgConnection, Postgres}; use crate::row::{Row, RowIndex}; -use crate::types::HasSqlType; +use crate::types::Type; -pub struct PgRow { +pub struct PgRow<'c> { + pub(super) connection: MaybeOwnedConnection<'c, PgConnection>, pub(super) data: DataRow, pub(super) columns: Arc, usize>>, } -impl Row for PgRow { +impl<'c> Row<'c> for PgRow<'c> { type Database = Postgres; fn len(&self) -> usize { self.data.len() } - fn get(&self, index: I) -> T + fn try_get_raw<'i, I>(&'c self, index: I) -> crate::Result> where - Self::Database: HasSqlType, - I: RowIndex, - T: Decode, + I: RowIndex<'c, Self> + 'i, { - index.try_get(self).unwrap() + index.try_get_raw(self) } } -impl RowIndex for usize { - fn try_get(&self, row: &PgRow) -> crate::Result - where - ::Database: HasSqlType, - T: Decode<::Database>, - { - Ok(Decode::decode_nullable(row.data.get(*self))?) +impl<'c> RowIndex<'c, PgRow<'c>> for usize { + fn try_get_raw(self, row: &'c PgRow<'c>) -> crate::Result> { + Ok(row.data.get( + row.connection.stream.buffer(), + &row.connection.data_row_values_buf, + self, + )) } } -impl RowIndex for &'_ str { - fn try_get(&self, row: &PgRow) -> crate::Result - where - ::Database: HasSqlType, - T: Decode<::Database>, - { - let index = row - .columns - .get(*self) - .ok_or_else(|| crate::Error::ColumnNotFound((*self).into()))?; - let value = Decode::decode_nullable(row.data.get(*index))?; +// impl<'c> RowIndex<'c, PgRow<'c>> for &'_ str { +// fn try_get_raw(self, row: &'r PgRow<'c>) -> crate::Result> { +// let index = row +// .columns +// .get(self) +// .ok_or_else(|| crate::Error::ColumnNotFound((*self).into()))?; +// +// Ok(row.data.get( +// row.connection.stream.buffer(), +// &row.connection.data_row_values_buf, +// *index, +// )) +// } +// } - Ok(value) - } -} - -impl_from_row_for_row!(PgRow); +// TODO: impl_from_row_for_row!(PgRow); diff --git a/sqlx-core/src/postgres/sasl.rs b/sqlx-core/src/postgres/sasl.rs index 27f5c4ea..73df9157 100644 --- a/sqlx-core/src/postgres/sasl.rs +++ b/sqlx-core/src/postgres/sasl.rs @@ -3,8 +3,10 @@ use rand::Rng; use sha2::{Digest, Sha256}; use crate::postgres::protocol::{ - hi, Authentication, Encode, Message, SaslInitialResponse, SaslResponse, + hi, Authentication, AuthenticationSaslContinue, Encode, Message, SaslInitialResponse, + SaslResponse, }; +use crate::postgres::stream::PgStream; use crate::postgres::PgConnection; static GS2_HEADER: &'static str = "n,,"; @@ -43,7 +45,7 @@ fn nonce() -> String { // Performs authenticiton using Simple Authentication Security Layer (SASL) which is what // Postgres uses pub(super) async fn authenticate>( - conn: &mut PgConnection, + stream: &mut PgStream, username: T, password: T, ) -> crate::Result<()> { @@ -62,13 +64,18 @@ pub(super) async fn authenticate>( client_first_message_bare = client_first_message_bare ); - SaslInitialResponse(&client_first_message).encode(conn.stream.buffer_mut()); - conn.stream.flush().await?; + stream.write(SaslInitialResponse(&client_first_message)); + stream.flush().await?; - let server_first_message = conn.receive().await?; + let server_first_message = stream.read().await?; + + if let Message::Authentication = server_first_message { + let auth = Authentication::read(stream.buffer())?; + + if let Authentication::SaslContinue = auth { + // todo: better way to indicate that we consumed just these 4 bytes? + let sasl = AuthenticationSaslContinue::read(&stream.buffer()[4..])?; - if let Some(Message::Authentication(auth)) = server_first_message { - if let Authentication::SaslContinue(sasl) = *auth { let server_first_message = sasl.data; // SaltedPassword := Hi(Normalize(password), salt, i) @@ -132,9 +139,11 @@ pub(super) async fn authenticate>( client_proof = base64::encode(&client_proof) ); - SaslResponse(&client_final_message).encode(conn.stream.buffer_mut()); - conn.stream.flush().await?; - let _server_final_response = conn.receive().await?; + stream.write(SaslResponse(&client_final_message)); + stream.flush().await?; + + let _server_final_response = stream.read().await?; + // todo: assert that this was SaslFinal? Ok(()) } else { diff --git a/sqlx-core/src/postgres/stream.rs b/sqlx-core/src/postgres/stream.rs new file mode 100644 index 00000000..671e3fe5 --- /dev/null +++ b/sqlx-core/src/postgres/stream.rs @@ -0,0 +1,90 @@ +use std::convert::TryInto; +use std::net::Shutdown; + +use byteorder::NetworkEndian; + +use crate::io::{Buf, BufStream, MaybeTlsStream}; +use crate::postgres::protocol::{Encode, Message, Response}; +use crate::postgres::PgError; +use crate::url::Url; + +pub struct PgStream { + stream: BufStream, + + // Most recently received message + // Is referenced by our buffered stream + // Is initialized to ReadyForQuery/0 at the start + message: (Message, u32), +} + +impl PgStream { + pub(super) async fn new(url: &Url) -> crate::Result { + let stream = MaybeTlsStream::connect(&url, 5432).await?; + + Ok(Self { + stream: BufStream::new(stream), + message: (Message::ReadyForQuery, 0), + }) + } + + pub(super) fn shutdown(&self) -> crate::Result<()> { + Ok(self.stream.shutdown(Shutdown::Both)?) + } + + #[inline] + pub(super) fn write(&mut self, message: M) + where + M: Encode, + { + message.encode(self.stream.buffer_mut()); + } + + #[inline] + pub(super) async fn flush(&mut self) -> crate::Result<()> { + Ok(self.stream.flush().await?) + } + + pub(super) async fn read(&mut self) -> crate::Result { + // https://www.postgresql.org/docs/12/protocol-overview.html#PROTOCOL-MESSAGE-CONCEPTS + + // All communication is through a stream of messages. The first byte of a message + // identifies the message type, and the next four bytes specify the length of the rest of + // the message (this length count includes itself, but not the message-type byte). + + if self.message.1 > 0 { + // If there is any data in our read buffer we need to make sure we flush that + // so reading will return the *next* message + self.stream.consume(self.message.1 as usize); + } + + let mut header = self.stream.peek(4 + 1).await?; + + let type_ = header.get_u8()?.try_into()?; + let length = header.get_u32::()? - 4; + + self.message = (type_, length); + self.stream.consume(4 + 1); + + // Wait until there is enough data in the stream. We then return without actually + // inspecting the data. This is then looked at later through the [buffer] function + let _ = self.stream.peek(length as usize).await?; + + if let Message::ErrorResponse = type_ { + // This is an error, bubble up as one immediately + return Err(crate::Error::Database(Box::new(PgError(Response::read( + self.stream.buffer(), + )?)))); + } + + Ok(type_) + } + + /// Returns a reference to the internally buffered message. + /// + /// This is the body of the message identified by the most recent call + /// to `read`. + #[inline] + pub(super) fn buffer(&self) -> &[u8] { + &self.stream.buffer()[..(self.message.1 as usize)] + } +} diff --git a/sqlx-core/src/postgres/types/bool.rs b/sqlx-core/src/postgres/types/bool.rs index 56e220cf..e3932fd0 100644 --- a/sqlx-core/src/postgres/types/bool.rs +++ b/sqlx-core/src/postgres/types/bool.rs @@ -3,15 +3,15 @@ use crate::encode::Encode; use crate::postgres::protocol::TypeId; use crate::postgres::types::PgTypeInfo; use crate::postgres::Postgres; -use crate::types::HasSqlType; +use crate::types::Type; -impl HasSqlType for Postgres { +impl Type for bool { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::BOOL) } } -impl HasSqlType<[bool]> for Postgres { +impl Type for [bool] { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::ARRAY_BOOL) } diff --git a/sqlx-core/src/postgres/types/bytes.rs b/sqlx-core/src/postgres/types/bytes.rs index 0c06b908..f64ac337 100644 --- a/sqlx-core/src/postgres/types/bytes.rs +++ b/sqlx-core/src/postgres/types/bytes.rs @@ -3,24 +3,24 @@ use crate::encode::Encode; use crate::postgres::protocol::TypeId; use crate::postgres::types::PgTypeInfo; use crate::postgres::Postgres; -use crate::types::HasSqlType; +use crate::types::Type; -impl HasSqlType<[u8]> for Postgres { +impl Type for [u8] { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::BYTEA) } } -impl HasSqlType<[&'_ [u8]]> for Postgres { +impl Type for [&'_ [u8]] { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::ARRAY_BYTEA) } } // TODO: Do we need the [HasSqlType] here on the Vec? -impl HasSqlType> for Postgres { +impl Type for Vec { fn type_info() -> PgTypeInfo { - >::type_info() + <[u8] as Type>::type_info() } } diff --git a/sqlx-core/src/postgres/types/chrono.rs b/sqlx-core/src/postgres/types/chrono.rs index d73d1d89..cf06f2b4 100644 --- a/sqlx-core/src/postgres/types/chrono.rs +++ b/sqlx-core/src/postgres/types/chrono.rs @@ -8,27 +8,27 @@ use crate::encode::Encode; use crate::postgres::protocol::TypeId; use crate::postgres::types::PgTypeInfo; use crate::postgres::Postgres; -use crate::types::HasSqlType; +use crate::types::Type; -impl HasSqlType for Postgres { +impl Type for NaiveTime { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::TIME) } } -impl HasSqlType for Postgres { +impl Type for NaiveDate { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::DATE) } } -impl HasSqlType for Postgres { +impl Type for NaiveDateTime { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::TIMESTAMP) } } -impl HasSqlType> for Postgres +impl Type> for Postgres where Tz: TimeZone, { @@ -37,25 +37,25 @@ where } } -impl HasSqlType<[NaiveTime]> for Postgres { +impl Type for [NaiveTime] { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::ARRAY_TIME) } } -impl HasSqlType<[NaiveDate]> for Postgres { +impl Type for [NaiveDate] { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::ARRAY_DATE) } } -impl HasSqlType<[NaiveDateTime]> for Postgres { +impl Type for [NaiveDateTime] { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::ARRAY_TIMESTAMP) } } -impl HasSqlType<[DateTime]> for Postgres +impl Type<[DateTime]> for Postgres where Tz: TimeZone, { diff --git a/sqlx-core/src/postgres/types/float.rs b/sqlx-core/src/postgres/types/float.rs index 48539d98..fb50cca4 100644 --- a/sqlx-core/src/postgres/types/float.rs +++ b/sqlx-core/src/postgres/types/float.rs @@ -3,15 +3,15 @@ use crate::encode::Encode; use crate::postgres::protocol::TypeId; use crate::postgres::types::PgTypeInfo; use crate::postgres::Postgres; -use crate::types::HasSqlType; +use crate::types::Type; -impl HasSqlType for Postgres { +impl Type for f32 { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::FLOAT4) } } -impl HasSqlType<[f32]> for Postgres { +impl Type for [f32] { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::ARRAY_FLOAT4) } @@ -31,13 +31,13 @@ impl Decode for f32 { } } -impl HasSqlType for Postgres { +impl Type for f64 { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::FLOAT8) } } -impl HasSqlType<[f64]> for Postgres { +impl Type for [f64] { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::ARRAY_FLOAT8) } diff --git a/sqlx-core/src/postgres/types/int.rs b/sqlx-core/src/postgres/types/int.rs index b2755883..39e195f3 100644 --- a/sqlx-core/src/postgres/types/int.rs +++ b/sqlx-core/src/postgres/types/int.rs @@ -5,15 +5,15 @@ use crate::encode::Encode; use crate::postgres::protocol::TypeId; use crate::postgres::types::PgTypeInfo; use crate::postgres::Postgres; -use crate::types::HasSqlType; +use crate::types::Type; -impl HasSqlType for Postgres { +impl Type for i16 { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::INT2) } } -impl HasSqlType<[i16]> for Postgres { +impl Type for [i16] { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::ARRAY_INT2) } @@ -31,13 +31,13 @@ impl Decode for i16 { } } -impl HasSqlType for Postgres { +impl Type for i32 { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::INT4) } } -impl HasSqlType<[i32]> for Postgres { +impl Type for [i32] { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::ARRAY_INT4) } @@ -55,13 +55,13 @@ impl Decode for i32 { } } -impl HasSqlType for Postgres { +impl Type for i64 { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::INT8) } } -impl HasSqlType<[i64]> for Postgres { +impl Type for [i64] { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::ARRAY_INT8) } diff --git a/sqlx-core/src/postgres/types/str.rs b/sqlx-core/src/postgres/types/str.rs index 56f74ff9..5bc0f66a 100644 --- a/sqlx-core/src/postgres/types/str.rs +++ b/sqlx-core/src/postgres/types/str.rs @@ -4,25 +4,25 @@ use crate::decode::{Decode, DecodeError}; use crate::encode::Encode; use crate::postgres::protocol::TypeId; use crate::postgres::types::PgTypeInfo; -use crate::types::HasSqlType; +use crate::types::Type; use crate::Postgres; -impl HasSqlType for Postgres { +impl Type for str { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::TEXT) } } -impl HasSqlType<[&'_ str]> for Postgres { +impl Type for [&'_ str] { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::ARRAY_TEXT) } } // TODO: Do we need [HasSqlType] on String here? -impl HasSqlType for Postgres { +impl Type for String { fn type_info() -> PgTypeInfo { - >::type_info() + >::type_info() } } diff --git a/sqlx-core/src/postgres/types/uuid.rs b/sqlx-core/src/postgres/types/uuid.rs index fecc408c..b3cb6b36 100644 --- a/sqlx-core/src/postgres/types/uuid.rs +++ b/sqlx-core/src/postgres/types/uuid.rs @@ -5,15 +5,15 @@ use crate::encode::Encode; use crate::postgres::protocol::TypeId; use crate::postgres::types::PgTypeInfo; use crate::postgres::Postgres; -use crate::types::HasSqlType; +use crate::types::Type; -impl HasSqlType for Postgres { +impl Type for Uuid { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::UUID) } } -impl HasSqlType<[Uuid]> for Postgres { +impl Type for [Uuid] { fn type_info() -> PgTypeInfo { PgTypeInfo::new(TypeId::ARRAY_UUID) } diff --git a/sqlx-core/src/query.rs b/sqlx-core/src/query.rs index 3d9ef8cc..8c9138b4 100644 --- a/sqlx-core/src/query.rs +++ b/sqlx-core/src/query.rs @@ -4,68 +4,69 @@ use crate::cursor::Cursor; use crate::database::{Database, HasCursor, HasRow}; use crate::encode::Encode; use crate::executor::{Execute, Executor}; -use crate::types::HasSqlType; +use crate::types::Type; use futures_core::stream::BoxStream; use futures_util::future::ready; use futures_util::TryFutureExt; use futures_util::TryStreamExt; use std::future::Future; use std::marker::PhantomData; +use std::mem; /// Raw SQL query with bind parameters. Returned by [`query`]. -pub struct Query<'a, DB, T = ::Arguments> +pub struct Query<'q, DB, T = ::Arguments> where DB: Database, { - query: &'a str, + query: &'q str, arguments: T, database: PhantomData, } -impl<'a, DB, P> Execute<'a, DB> for Query<'a, DB, P> +impl<'q, DB, P> Execute<'q, DB> for Query<'q, DB, P> where DB: Database, P: IntoArguments + Send, { - fn into_parts(self) -> (&'a str, Option<::Arguments>) { + fn into_parts(self) -> (&'q str, Option<::Arguments>) { (self.query, Some(self.arguments.into_arguments())) } } -impl<'a, DB, P> Query<'a, DB, P> +impl<'q, DB, P> Query<'q, DB, P> where DB: Database, P: IntoArguments + Send, { - pub fn execute<'b, E>(self, executor: E) -> impl Future> + 'b + pub async fn execute<'e, E>(self, executor: E) -> crate::Result where - E: Executor<'b, Database = DB>, - 'a: 'b, + E: Executor<'e, Database = DB>, + { + executor.execute(self).await + } + + pub fn fetch<'e, E>(self, executor: E) -> >::Cursor + where + E: Executor<'e, Database = DB>, { executor.execute(self) } - pub fn fetch<'b, E>(self, executor: E) -> >::Cursor - where - E: Executor<'b, Database = DB>, - 'a: 'b, - { - executor.execute(self) - } - - pub async fn fetch_optional<'b, E>( + pub async fn fetch_optional<'e, E>( self, executor: E, - ) -> crate::Result::Row>> + ) -> crate::Result>::Row>> where - E: Executor<'b, Database = DB>, + E: Executor<'e, Database = DB>, + 'q: 'e, { executor.execute(self).first().await } - pub async fn fetch_one<'b, E>(self, executor: E) -> crate::Result<::Row> + pub async fn fetch_one<'e, E>(self, executor: E) -> crate::Result<>::Row> where - E: Executor<'b, Database = DB>, + E: Executor<'e, Database = DB>, + 'q: 'e, { self.fetch_optional(executor) .and_then(|row| match row { @@ -83,7 +84,7 @@ where /// Bind a value for use with this SQL query. pub fn bind(mut self, value: T) -> Self where - DB: HasSqlType, + T: Type, T: Encode, { self.arguments.add(value); diff --git a/sqlx-core/src/row.rs b/sqlx-core/src/row.rs index adaefeb1..780f99b1 100644 --- a/sqlx-core/src/row.rs +++ b/sqlx-core/src/row.rs @@ -2,20 +2,17 @@ use crate::database::Database; use crate::decode::Decode; -use crate::types::HasSqlType; +use crate::types::Type; -pub trait RowIndex +pub trait RowIndex<'c, R: ?Sized> where - R: Row, + R: Row<'c>, { - fn try_get(&self, row: &R) -> crate::Result - where - R::Database: HasSqlType, - T: Decode; + fn try_get_raw(self, row: &'c R) -> crate::Result>; } /// Represents a single row of the result set. -pub trait Row: Unpin + Send + 'static { +pub trait Row<'c>: Unpin + Send { type Database: Database + ?Sized; /// Returns `true` if the row contains no values. @@ -26,18 +23,34 @@ pub trait Row: Unpin + Send + 'static { /// Returns the number of values in the row. fn len(&self) -> usize; - /// Returns the value at the `index`; can either be an integer ordinal or a column name. - fn get(&self, index: I) -> T + fn get(&'c self, index: I) -> T where - Self::Database: HasSqlType, - I: RowIndex, - T: Decode; + T: Type, + I: RowIndex<'c, Self>, + T: Decode, + { + // todo: use expect with a proper message + self.try_get(index).unwrap() + } + + fn try_get(&'c self, index: I) -> crate::Result + where + T: Type, + I: RowIndex<'c, Self>, + T: Decode, + { + Ok(Decode::decode_nullable(self.try_get_raw(index)?)?) + } + + fn try_get_raw<'i, I>(&'c self, index: I) -> crate::Result> + where + I: RowIndex<'c, Self> + 'i; } /// A **record** that can be built from a row returned from by the database. -pub trait FromRow +pub trait FromRow<'a, R> where - R: Row, + R: Row<'a>, { fn from_row(row: R) -> Self; } diff --git a/sqlx-core/src/transaction.rs b/sqlx-core/src/transaction.rs index 1ce2a5da..799d4173 100644 --- a/sqlx-core/src/transaction.rs +++ b/sqlx-core/src/transaction.rs @@ -4,6 +4,7 @@ use futures_core::future::BoxFuture; use crate::connection::Connection; use crate::database::HasCursor; +use crate::describe::Describe; use crate::executor::{Execute, Executor}; use crate::runtime::spawn; use crate::Database; @@ -19,10 +20,11 @@ where depth: u32, } -impl Transaction +impl Transaction where - T: Connection, - T: Executor<'static>, + T: Connection, + DB: Database, + T: Executor<'static, Database = DB>, { pub(crate) async fn new(depth: u32, mut inner: T) -> crate::Result { if depth == 0 { @@ -98,10 +100,11 @@ where } } -impl Connection for Transaction +impl Connection for Transaction where - T: Connection, - T: Executor<'static>, + T: Connection, + DB: Database, + T: Executor<'static, Database = DB>, { type Database = ::Database; @@ -109,9 +112,23 @@ where fn close(self) -> BoxFuture<'static, crate::Result<()>> { Box::pin(async move { self.rollback().await?.close().await }) } + + #[inline] + fn ping(&mut self) -> BoxFuture> { + Box::pin(self.deref_mut().ping()) + } + + #[doc(hidden)] + #[inline] + fn describe<'e, 'q: 'e>( + &'e mut self, + query: &'q str, + ) -> BoxFuture<'e, crate::Result>> { + Box::pin(self.deref_mut().describe(query)) + } } -impl<'a, DB, T> Executor<'a> for &'a mut Transaction +impl<'c, DB, T> Executor<'c> for &'c mut Transaction where DB: Database, T: Connection, @@ -119,19 +136,19 @@ where { type Database = ::Database; - fn execute<'b, E>(self, query: E) -> <::Database as HasCursor<'a>>::Cursor + fn execute<'q, E>(self, query: E) -> <::Database as HasCursor<'c, 'q>>::Cursor where - E: Execute<'b, Self::Database>, + E: Execute<'q, Self::Database>, { (**self).execute_by_ref(query) } - fn execute_by_ref<'b, 'c, E>( - &'c mut self, + fn execute_by_ref<'q, 'e, E>( + &'e mut self, query: E, - ) -> >::Cursor + ) -> >::Cursor where - E: Execute<'b, Self::Database>, + E: Execute<'q, Self::Database>, { (**self).execute_by_ref(query) } diff --git a/sqlx-core/src/types.rs b/sqlx-core/src/types.rs index b8371a94..d2db88f6 100644 --- a/sqlx-core/src/types.rs +++ b/sqlx-core/src/types.rs @@ -21,29 +21,34 @@ pub trait TypeInfo: Debug + Display + Clone { } /// Indicates that a SQL type is supported for a database. -pub trait HasSqlType: Database { +pub trait Type +where + DB: Database, +{ /// Returns the canonical type information on the database for the type `T`. - fn type_info() -> Self::TypeInfo; + fn type_info() -> DB::TypeInfo; } // For references to types in Rust, the underlying SQL type information // is equivalent -impl HasSqlType<&'_ T> for DB +impl Type for &'_ T where - DB: HasSqlType, + DB: Database, + T: Type, { - fn type_info() -> Self::TypeInfo { - >::type_info() + fn type_info() -> DB::TypeInfo { + >::type_info() } } // For optional types in Rust, the underlying SQL type information // is equivalent -impl HasSqlType> for DB +impl Type for Option where - DB: HasSqlType, + DB: Database, + T: Type, { - fn type_info() -> Self::TypeInfo { - >::type_info() + fn type_info() -> DB::TypeInfo { + >::type_info() } } diff --git a/sqlx-macros/src/database/mod.rs b/sqlx-macros/src/database/mod.rs index 08bde2f8..3696a690 100644 --- a/sqlx-macros/src/database/mod.rs +++ b/sqlx-macros/src/database/mod.rs @@ -32,7 +32,7 @@ macro_rules! impl_database_ext { $( // `if` statements cannot have attributes but these can $(#[$meta])? - _ if sqlx::types::TypeInfo::compatible(&<$database as sqlx::types::HasSqlType<$ty>>::type_info(), &info) => Some(input_ty!($ty $(, $input)?)), + _ if sqlx::types::TypeInfo::compatible(&<$ty as sqlx::types::Type<$database>>::type_info(), &info) => Some(input_ty!($ty $(, $input)?)), )* _ => None } @@ -42,7 +42,7 @@ macro_rules! impl_database_ext { match () { $( $(#[$meta])? - _ if sqlx::types::TypeInfo::compatible(&<$database as sqlx::types::HasSqlType<$ty>>::type_info(), &info) => return Some(stringify!($ty)), + _ if sqlx::types::TypeInfo::compatible(&<$ty as sqlx::types::Type<$database>>::type_info(), &info) => return Some(stringify!($ty)), )* _ => None } diff --git a/src/lib.rs b/src/lib.rs index 2f982871..d029c984 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,7 +12,7 @@ pub use sqlx_core::{arguments, describe, error, pool, row, types}; // Types pub use sqlx_core::{ - Connect, Connection, Database, Error, Executor, FromRow, Pool, Query, QueryAs, Result, Row, + Connect, Connection, Cursor, Database, Error, Executor, FromRow, Pool, Query, QueryAs, Result, Row, Transaction, }; diff --git a/tests/postgres-types.rs b/tests/postgres-types.rs index 177c8e0a..7f5bc393 100644 --- a/tests/postgres-types.rs +++ b/tests/postgres-types.rs @@ -1,59 +1,59 @@ -use sqlx::{postgres::PgConnection, Connection as _, Row}; - -async fn connect() -> anyhow::Result { - Ok(PgConnection::open(dotenv::var("DATABASE_URL")?).await?) -} - -macro_rules! test { - ($name:ident: $ty:ty: $($text:literal == $value:expr),+) => { - #[cfg_attr(feature = "runtime-async-std", async_std::test)] -#[cfg_attr(feature = "runtime-tokio", tokio::test)] - async fn $name () -> anyhow::Result<()> { - let mut conn = connect().await?; - - $( - let row = sqlx::query(&format!("SELECT {} = $1, $1 as _1", $text)) - .bind($value) - .fetch_one(&mut conn) - .await?; - - assert!(row.get::(0)); - assert!($value == row.get::<$ty, _>("_1")); - )+ - - Ok(()) - } - } -} - -test!(postgres_bool: bool: "false::boolean" == false, "true::boolean" == true); - -test!(postgres_smallint: i16: "821::smallint" == 821_i16); -test!(postgres_int: i32: "94101::int" == 94101_i32); -test!(postgres_bigint: i64: "9358295312::bigint" == 9358295312_i64); - -test!(postgres_real: f32: "9419.122::real" == 9419.122_f32); -test!(postgres_double: f64: "939399419.1225182::double precision" == 939399419.1225182_f64); - -test!(postgres_text: String: "'this is foo'" == "this is foo", "''" == ""); - -#[cfg_attr(feature = "runtime-async-std", async_std::test)] -#[cfg_attr(feature = "runtime-tokio", tokio::test)] -async fn postgres_bytes() -> anyhow::Result<()> { - let mut conn = connect().await?; - - let value = b"Hello, World"; - - let row = sqlx::query("SELECT E'\\\\x48656c6c6f2c20576f726c64' = $1, $1") - .bind(&value[..]) - .fetch_one(&mut conn) - .await?; - - assert!(row.get::(0)); - - let output: Vec = row.get(1); - - assert_eq!(&value[..], &*output); - - Ok(()) -} +// use sqlx::{postgres::PgConnection, Connect as _, Connection as _, Row}; +// +// async fn connect() -> anyhow::Result { +// Ok(PgConnection::connect(dotenv::var("DATABASE_URL")?).await?) +// } +// +// macro_rules! test { +// ($name:ident: $ty:ty: $($text:literal == $value:expr),+) => { +// #[cfg_attr(feature = "runtime-async-std", async_std::test)] +// #[cfg_attr(feature = "runtime-tokio", tokio::test)] +// async fn $name () -> anyhow::Result<()> { +// let mut conn = connect().await?; +// +// $( +// let row = sqlx::query(&format!("SELECT {} = $1, $1 as _1", $text)) +// .bind($value) +// .fetch_one(&mut conn) +// .await?; +// +// assert!(row.get::(0)); +// assert!($value == row.get::<$ty, _>("_1")); +// )+ +// +// Ok(()) +// } +// } +// } +// +// test!(postgres_bool: bool: "false::boolean" == false, "true::boolean" == true); +// +// test!(postgres_smallint: i16: "821::smallint" == 821_i16); +// test!(postgres_int: i32: "94101::int" == 94101_i32); +// test!(postgres_bigint: i64: "9358295312::bigint" == 9358295312_i64); +// +// test!(postgres_real: f32: "9419.122::real" == 9419.122_f32); +// test!(postgres_double: f64: "939399419.1225182::double precision" == 939399419.1225182_f64); +// +// test!(postgres_text: String: "'this is foo'" == "this is foo", "''" == ""); +// +// #[cfg_attr(feature = "runtime-async-std", async_std::test)] +// #[cfg_attr(feature = "runtime-tokio", tokio::test)] +// async fn postgres_bytes() -> anyhow::Result<()> { +// let mut conn = connect().await?; +// +// let value = b"Hello, World"; +// +// let row = sqlx::query("SELECT E'\\\\x48656c6c6f2c20576f726c64' = $1, $1") +// .bind(&value[..]) +// .fetch_one(&mut conn) +// .await?; +// +// assert!(row.get::(0)); +// +// let output: Vec = row.get(1); +// +// assert_eq!(&value[..], &*output); +// +// Ok(()) +// } diff --git a/tests/postgres.rs b/tests/postgres.rs index f4242eea..61a0699e 100644 --- a/tests/postgres.rs +++ b/tests/postgres.rs @@ -1,5 +1,5 @@ use futures::TryStreamExt; -use sqlx::{postgres::PgConnection, Connection as _, Executor as _, Row as _}; +use sqlx::{postgres::PgConnection, Connect, Connection, Executor, Row}; use sqlx_core::postgres::PgPool; use std::time::Duration; @@ -17,58 +17,40 @@ async fn it_connects() -> anyhow::Result<()> { Ok(()) } -#[cfg_attr(feature = "runtime-async-std", async_std::test)] -#[cfg_attr(feature = "runtime-tokio", tokio::test)] -async fn it_executes() -> anyhow::Result<()> { - let mut conn = connect().await?; - - let _ = conn - .send( - r#" -CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY); - "#, - ) - .await?; - - for index in 1..=10_i32 { - let cnt = sqlx::query("INSERT INTO users (id) VALUES ($1)") - .bind(index) - .execute(&mut conn) - .await?; - - assert_eq!(cnt, 1); - } - - let sum: i32 = sqlx::query("SELECT id FROM users") - .fetch(&mut conn) - .try_fold( - 0_i32, - |acc, x| async move { Ok(acc + x.get::("id")) }, - ) - .await?; - - assert_eq!(sum, 55); - - Ok(()) -} - -#[cfg_attr(feature = "runtime-async-std", async_std::test)] -#[cfg_attr(feature = "runtime-tokio", tokio::test)] -async fn it_remains_stable_issue_30() -> anyhow::Result<()> { - let mut conn = connect().await?; - - // This tests the internal buffer wrapping around at the end - // Specifically: https://github.com/launchbadge/sqlx/issues/30 - - let rows = sqlx::query("SELECT i, random()::text FROM generate_series(1, 1000) as i") - .fetch_all(&mut conn) - .await?; - - assert_eq!(rows.len(), 1000); - assert_eq!(rows[rows.len() - 1].get::(0), 1000); - - Ok(()) -} +// #[cfg_attr(feature = "runtime-async-std", async_std::test)] +// #[cfg_attr(feature = "runtime-tokio", tokio::test)] +// async fn it_executes() -> anyhow::Result<()> { +// let mut conn = connect().await?; +// +// let _ = conn +// .send( +// r#" +// CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY); +// "#, +// ) +// .await?; +// +// for index in 1..=10_i32 { +// let cnt = sqlx::query("INSERT INTO users (id) VALUES ($1)") +// .bind(index) +// .execute(&mut conn) +// .await?; +// +// assert_eq!(cnt, 1); +// } +// +// let sum: i32 = sqlx::query("SELECT id FROM users") +// .fetch(&mut conn) +// .try_fold( +// 0_i32, +// |acc, x| async move { Ok(acc + x.get::("id")) }, +// ) +// .await?; +// +// assert_eq!(sum, 55); +// +// Ok(()) +// } // https://github.com/launchbadge/sqlx/issues/104 #[cfg_attr(feature = "runtime-async-std", async_std::test)] @@ -122,7 +104,7 @@ async fn pool_smoke_test() -> anyhow::Result<()> { let pool = pool.clone(); spawn(async move { loop { - if let Err(e) = sqlx::query("select 1 + 1").fetch_one(&mut &pool).await { + if let Err(e) = sqlx::query("select 1 + 1").fetch_one(&pool).await { eprintln!("pool task {} dying due to {}", i, e); break; } @@ -159,5 +141,5 @@ async fn pool_smoke_test() -> anyhow::Result<()> { async fn connect() -> anyhow::Result { let _ = dotenv::dotenv(); let _ = env_logger::try_init(); - Ok(PgConnection::open(dotenv::var("DATABASE_URL")?).await?) + Ok(PgConnection::connect(dotenv::var("DATABASE_URL")?).await?) }