postgres: rewrite protocol in more iterative and lazy fashion

This commit is contained in:
Ryan Leckey 2020-02-19 08:10:27 -08:00
parent 3795d15e1c
commit a374c18a18
60 changed files with 1586 additions and 931 deletions

9
Cargo.lock generated
View File

@ -1526,6 +1526,15 @@ dependencies = [
"uuid 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)", "uuid 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)",
] ]
[[package]]
name = "sqlx-example-postgres-basic"
version = "0.1.0"
dependencies = [
"anyhow 1.0.26 (registry+https://github.com/rust-lang/crates.io-index)",
"async-std 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
"sqlx 0.2.5",
]
[[package]] [[package]]
name = "sqlx-example-realworld-postgres" name = "sqlx-example-realworld-postgres"
version = "0.1.0" version = "0.1.0"

View File

@ -3,6 +3,7 @@ members = [
".", ".",
"sqlx-core", "sqlx-core",
"sqlx-macros", "sqlx-macros",
"examples/postgres/basic",
"examples/realworld-postgres", "examples/realworld-postgres",
"examples/todos-postgres", "examples/todos-postgres",
] ]

View File

@ -0,0 +1,10 @@
[package]
workspace = "../../.."
name = "sqlx-example-postgres-basic"
version = "0.1.0"
edition = "2018"
[dependencies]
async-std = { version = "1", features = [ "attributes" ] }
anyhow = "1"
sqlx = { path = "../../..", features = [ "postgres" ] }

View File

@ -0,0 +1,25 @@
use sqlx::{Connect, Connection, Cursor, Executor, PgConnection, Row};
use std::convert::TryInto;
use std::time::Instant;
#[async_std::main]
async fn main() -> anyhow::Result<()> {
let mut conn = PgConnection::connect("postgres://").await?;
let mut rows = sqlx::query("SELECT definition FROM pg_database")
.execute(&mut conn)
.await?;
// let start = Instant::now();
// while let Some(row) = cursor.next().await? {
// // let raw = row.try_get(0)?.unwrap();
//
// // println!("hai: {:?}", raw);
// }
println!("?? = {}", rows);
// conn.close().await?;
Ok(())
}

View File

@ -1,2 +0,0 @@
select * from (select (1) as id, 'Herp Derpinson' as name) accounts
where id = ?

View File

@ -2,7 +2,7 @@
use crate::database::Database; use crate::database::Database;
use crate::encode::Encode; use crate::encode::Encode;
use crate::types::HasSqlType; use crate::types::Type;
/// A tuple of arguments to be sent to the database. /// A tuple of arguments to be sent to the database.
pub trait Arguments: Send + Sized + Default + 'static { pub trait Arguments: Send + Sized + Default + 'static {
@ -15,7 +15,7 @@ pub trait Arguments: Send + Sized + Default + 'static {
/// Add the value to the end of the arguments. /// Add the value to the end of the arguments.
fn add<T>(&mut self, value: T) fn add<T>(&mut self, value: T)
where where
Self::Database: HasSqlType<T>, T: Type<Self::Database>,
T: Encode<Self::Database>; T: Encode<Self::Database>;
} }

View File

@ -1,10 +1,14 @@
use std::convert::TryInto;
use std::ops::{Deref, DerefMut};
use futures_core::future::BoxFuture;
use futures_util::TryFutureExt;
use crate::database::Database; use crate::database::Database;
use crate::describe::Describe; use crate::describe::Describe;
use crate::executor::Executor; use crate::executor::Executor;
use crate::pool::{Pool, PoolConnection};
use crate::url::Url; use crate::url::Url;
use futures_core::future::BoxFuture;
use futures_util::TryFutureExt;
use std::convert::TryInto;
/// Represents a single database connection rather than a pool of database connections. /// Represents a single database connection rather than a pool of database connections.
/// ///
@ -20,20 +24,13 @@ where
fn close(self) -> BoxFuture<'static, crate::Result<()>>; fn close(self) -> BoxFuture<'static, crate::Result<()>>;
/// Verifies a connection to the database is still alive. /// Verifies a connection to the database is still alive.
fn ping(&mut self) -> BoxFuture<crate::Result<()>> fn ping(&mut self) -> BoxFuture<crate::Result<()>>;
where
for<'a> &'a mut Self: Executor<'a>,
{
Box::pin((&mut *self).execute("SELECT 1").map_ok(|_| ()))
}
#[doc(hidden)] #[doc(hidden)]
fn describe<'e, 'q: 'e>( fn describe<'e, 'q: 'e>(
&'e mut self, &'e mut self,
query: &'q str, query: &'q str,
) -> BoxFuture<'e, crate::Result<Describe<Self::Database>>> { ) -> BoxFuture<'e, crate::Result<Describe<Self::Database>>>;
todo!("make this a required function");
}
} }
/// Represents a type that can directly establish a new connection. /// Represents a type that can directly establish a new connection.
@ -44,3 +41,125 @@ pub trait Connect: Connection {
T: TryInto<Url, Error = crate::Error>, T: TryInto<Url, Error = crate::Error>,
Self: Sized; Self: Sized;
} }
mod internal {
pub enum MaybeOwnedConnection<'c, C>
where
C: super::Connect,
{
Borrowed(&'c mut C),
Owned(super::PoolConnection<C>),
}
pub enum ConnectionSource<'c, C>
where
C: super::Connect,
{
Empty,
Connection(MaybeOwnedConnection<'c, C>),
Pool(super::Pool<C>),
}
}
pub(crate) use self::internal::{ConnectionSource, MaybeOwnedConnection};
impl<'c, C> MaybeOwnedConnection<'c, C>
where
C: Connect,
{
pub(crate) fn borrow(&mut self) -> MaybeOwnedConnection<'_, C> {
match self {
MaybeOwnedConnection::Borrowed(conn) => MaybeOwnedConnection::Borrowed(&mut *conn),
MaybeOwnedConnection::Owned(ref mut conn) => MaybeOwnedConnection::Borrowed(conn),
}
}
}
impl<'c, C, DB> ConnectionSource<'c, C>
where
C: Connect<Database = DB>,
DB: Database<Connection = C>,
{
pub(crate) async fn resolve_by_ref(&mut self) -> crate::Result<MaybeOwnedConnection<'_, C>> {
if let ConnectionSource::Pool(pool) = self {
*self =
ConnectionSource::Connection(MaybeOwnedConnection::Owned(pool.acquire().await?));
}
Ok(match self {
ConnectionSource::Empty => panic!("`PgCursor` must not be used after being polled"),
ConnectionSource::Connection(conn) => conn.borrow(),
ConnectionSource::Pool(_) => unreachable!(),
})
}
pub(crate) async fn resolve(mut self) -> crate::Result<MaybeOwnedConnection<'c, C>> {
if let ConnectionSource::Pool(pool) = self {
self = ConnectionSource::Connection(MaybeOwnedConnection::Owned(pool.acquire().await?));
}
Ok(self.into_connection())
}
pub(crate) fn into_connection(self) -> MaybeOwnedConnection<'c, C> {
match self {
ConnectionSource::Connection(conn) => conn,
ConnectionSource::Empty | ConnectionSource::Pool(_) => {
panic!("`PgCursor` must not be used after being polled");
}
}
}
}
impl<C> Default for ConnectionSource<'_, C>
where
C: Connect,
{
fn default() -> Self {
ConnectionSource::Empty
}
}
impl<'c, C> From<&'c mut C> for MaybeOwnedConnection<'c, C>
where
C: Connect,
{
fn from(conn: &'c mut C) -> Self {
MaybeOwnedConnection::Borrowed(conn)
}
}
impl<'c, C> From<PoolConnection<C>> for MaybeOwnedConnection<'c, C>
where
C: Connect,
{
fn from(conn: PoolConnection<C>) -> Self {
MaybeOwnedConnection::Owned(conn)
}
}
impl<'c, C> Deref for MaybeOwnedConnection<'c, C>
where
C: Connect,
{
type Target = C;
fn deref(&self) -> &Self::Target {
match self {
MaybeOwnedConnection::Borrowed(conn) => conn,
MaybeOwnedConnection::Owned(conn) => conn,
}
}
}
impl<'c, C> DerefMut for MaybeOwnedConnection<'c, C>
where
C: Connect,
{
fn deref_mut(&mut self) -> &mut Self::Target {
match self {
MaybeOwnedConnection::Borrowed(conn) => conn,
MaybeOwnedConnection::Owned(conn) => conn,
}
}
}

View File

@ -3,7 +3,10 @@ use std::future::Future;
use futures_core::future::BoxFuture; use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream; use futures_core::stream::BoxStream;
use crate::connection::MaybeOwnedConnection;
use crate::database::{Database, HasRow}; use crate::database::{Database, HasRow};
use crate::executor::Execute;
use crate::{Connect, Pool};
/// Represents a result set, which is generated by executing a query against the database. /// Represents a result set, which is generated by executing a query against the database.
/// ///
@ -13,7 +16,7 @@ use crate::database::{Database, HasRow};
/// Initially the `Cursor` is positioned before the first row. The `next` method moves the cursor /// Initially the `Cursor` is positioned before the first row. The `next` method moves the cursor
/// to the next row, and because it returns `None` when there are no more rows, it can be used /// to the next row, and because it returns `None` when there are no more rows, it can be used
/// in a `while` loop to iterate through all returned rows. /// in a `while` loop to iterate through all returned rows.
pub trait Cursor<'a> pub trait Cursor<'c, 'q>
where where
Self: Send, Self: Send,
// `.await`-ing a cursor will return the affected rows from the query // `.await`-ing a cursor will return the affected rows from the query
@ -21,16 +24,59 @@ where
{ {
type Database: Database; type Database: Database;
/// Fetch the first row in the result. Returns `None` if no row is present. // Construct the [Cursor] from a [Pool]
/// // Meant for internal use only
/// Returns `Error::MoreThanOneRow` if more than one row is in the result. // TODO: Anyone have any better ideas on how to instantiate cursors generically from a pool?
fn first(self) -> BoxFuture<'a, crate::Result<Option<<Self::Database as HasRow>::Row>>>; #[doc(hidden)]
fn from_pool<E>(pool: &Pool<<Self::Database as Database>::Connection>, query: E) -> Self
where
Self: Sized,
E: Execute<'q, Self::Database>;
#[doc(hidden)]
fn from_connection<E, C>(conn: C, query: E) -> Self
where
Self: Sized,
<Self::Database as Database>::Connection: Connect,
// MaybeOwnedConnection<'c, <Self::Database as Database>::Connection>:
// Connect<Database = Self::Database>,
C: Into<MaybeOwnedConnection<'c, <Self::Database as Database>::Connection>>,
E: Execute<'q, Self::Database>;
#[doc(hidden)]
fn first(self) -> BoxFuture<'c, crate::Result<Option<<Self::Database as HasRow<'c>>::Row>>>
where
'q: 'c;
/// Fetch the next row in the result. Returns `None` if there are no more rows. /// Fetch the next row in the result. Returns `None` if there are no more rows.
fn next(&mut self) -> BoxFuture<crate::Result<Option<<Self::Database as HasRow>::Row>>>; fn next(&mut self) -> BoxFuture<crate::Result<Option<<Self::Database as HasRow>::Row>>>;
/// Map the `Row`s in this result to a different type, returning a [`Stream`] of the results. /// Map the `Row`s in this result to a different type, returning a [`Stream`] of the results.
fn map<T, F>(self, f: F) -> BoxStream<'a, crate::Result<T>> fn map<T, F>(self, f: F) -> BoxStream<'c, crate::Result<T>>
where where
F: Fn(<Self::Database as HasRow>::Row) -> T; F: MapRowFn<Self::Database, T>,
T: 'c + Send + Unpin,
'q: 'c;
}
pub trait MapRowFn<DB, T>
where
Self: Send + Sync + 'static,
DB: Database,
DB: for<'c> HasRow<'c>,
{
fn call(&self, row: <DB as HasRow>::Row) -> T;
}
impl<DB, T, F> MapRowFn<DB, T> for F
where
DB: Database,
DB: for<'c> HasRow<'c>,
F: Send + Sync + 'static,
F: Fn(<DB as HasRow>::Row) -> T,
{
#[inline(always)]
fn call(&self, row: <DB as HasRow>::Row) -> T {
self(row)
}
} }

View File

@ -13,9 +13,9 @@ use crate::types::TypeInfo;
pub trait Database pub trait Database
where where
Self: Sized + 'static, Self: Sized + 'static,
Self: HasRow<Database = Self>, Self: for<'a> HasRow<'a, Database = Self>,
Self: for<'a> HasRawValue<'a>, Self: for<'a> HasRawValue<'a>,
Self: for<'a> HasCursor<'a, Database = Self>, Self: for<'c, 'q> HasCursor<'c, 'q, Database = Self>,
{ {
/// The concrete `Connection` implementation for this database. /// The concrete `Connection` implementation for this database.
type Connection: Connection<Database = Self>; type Connection: Connection<Database = Self>;
@ -34,14 +34,14 @@ pub trait HasRawValue<'a> {
type RawValue; type RawValue;
} }
pub trait HasCursor<'a> { pub trait HasCursor<'c, 'q> {
type Database: Database; type Database: Database;
type Cursor: Cursor<'a, Database = Self::Database>; type Cursor: Cursor<'c, 'q, Database = Self::Database>;
} }
pub trait HasRow { pub trait HasRow<'a> {
type Database: Database; type Database: Database;
type Row: Row<Database = Self::Database>; type Row: Row<'a, Database = Self::Database>;
} }

View File

@ -4,7 +4,7 @@ use std::error::Error as StdError;
use std::fmt::{self, Display}; use std::fmt::{self, Display};
use crate::database::Database; use crate::database::Database;
use crate::types::HasSqlType; use crate::types::Type;
pub enum DecodeError { pub enum DecodeError {
/// An unexpected `NULL` was encountered while decoding. /// An unexpected `NULL` was encountered while decoding.
@ -40,7 +40,8 @@ where
impl<T, DB> Decode<DB> for Option<T> impl<T, DB> Decode<DB> for Option<T>
where where
DB: Database + HasSqlType<T>, DB: Database,
T: Type<DB>,
T: Decode<DB>, T: Decode<DB>,
{ {
fn decode(buf: &[u8]) -> Result<Self, DecodeError> { fn decode(buf: &[u8]) -> Result<Self, DecodeError> {

View File

@ -1,7 +1,7 @@
//! Types and traits for encoding values to the database. //! Types and traits for encoding values to the database.
use crate::database::Database; use crate::database::Database;
use crate::types::HasSqlType; use crate::types::Type;
use std::mem; use std::mem;
/// The return type of [Encode::encode]. /// The return type of [Encode::encode].
@ -36,7 +36,8 @@ where
impl<T: ?Sized, DB> Encode<DB> for &'_ T impl<T: ?Sized, DB> Encode<DB> for &'_ T
where where
DB: Database + HasSqlType<T>, DB: Database,
T: Type<DB>,
T: Encode<DB>, T: Encode<DB>,
{ {
fn encode(&self, buf: &mut Vec<u8>) { fn encode(&self, buf: &mut Vec<u8>) {
@ -54,7 +55,8 @@ where
impl<T, DB> Encode<DB> for Option<T> impl<T, DB> Encode<DB> for Option<T>
where where
DB: Database + HasSqlType<T>, DB: Database,
T: Type<DB>,
T: Encode<DB>, T: Encode<DB>,
{ {
fn encode(&self, buf: &mut Vec<u8>) { fn encode(&self, buf: &mut Vec<u8>) {

View File

@ -14,7 +14,7 @@ use futures_util::TryStreamExt;
/// Implementations are provided for [`&Pool`](struct.Pool.html), /// Implementations are provided for [`&Pool`](struct.Pool.html),
/// [`&mut PoolConnection`](struct.PoolConnection.html), /// [`&mut PoolConnection`](struct.PoolConnection.html),
/// and [`&mut Connection`](trait.Connection.html). /// and [`&mut Connection`](trait.Connection.html).
pub trait Executor<'a> pub trait Executor<'c>
where where
Self: Send, Self: Send,
{ {
@ -22,18 +22,18 @@ where
type Database: Database; type Database: Database;
/// Executes a query that may or may not return a result set. /// Executes a query that may or may not return a result set.
fn execute<'b, E>(self, query: E) -> <Self::Database as HasCursor<'a>>::Cursor fn execute<'q, E>(self, query: E) -> <Self::Database as HasCursor<'c, 'q>>::Cursor
where where
E: Execute<'b, Self::Database>; E: Execute<'q, Self::Database>;
#[doc(hidden)] #[doc(hidden)]
fn execute_by_ref<'b, E>(&mut self, query: E) -> <Self::Database as HasCursor<'_>>::Cursor fn execute_by_ref<'b, E>(&mut self, query: E) -> <Self::Database as HasCursor<'_, 'b>>::Cursor
where where
E: Execute<'b, Self::Database>; E: Execute<'b, Self::Database>;
} }
/// A type that may be executed against a database connection. /// A type that may be executed against a database connection.
pub trait Execute<'a, DB> pub trait Execute<'q, DB>
where where
DB: Database, DB: Database,
{ {
@ -43,15 +43,15 @@ where
/// prepare the query. Returning `Some(Default::default())` is an empty arguments object that /// prepare the query. Returning `Some(Default::default())` is an empty arguments object that
/// will be prepared (and cached) before execution. /// will be prepared (and cached) before execution.
#[doc(hidden)] #[doc(hidden)]
fn into_parts(self) -> (&'a str, Option<DB::Arguments>); fn into_parts(self) -> (&'q str, Option<DB::Arguments>);
} }
impl<'a, DB> Execute<'a, DB> for &'a str impl<'q, DB> Execute<'q, DB> for &'q str
where where
DB: Database, DB: Database,
{ {
#[inline] #[inline]
fn into_parts(self) -> (&'a str, Option<DB::Arguments>) { fn into_parts(self) -> (&'q str, Option<DB::Arguments>) {
(self, None) (self, None)
} }
} }

View File

@ -35,6 +35,11 @@ where
} }
} }
#[inline]
pub fn buffer(&self) -> &[u8] {
&self.rbuf[self.rbuf_rindex..]
}
#[inline] #[inline]
pub fn buffer_mut(&mut self) -> &mut Vec<u8> { pub fn buffer_mut(&mut self) -> &mut Vec<u8> {
&mut self.wbuf &mut self.wbuf
@ -61,7 +66,14 @@ where
self.rbuf_rindex += cnt; self.rbuf_rindex += cnt;
} }
pub async fn peek(&mut self, cnt: usize) -> io::Result<Option<&[u8]>> { pub async fn peek(&mut self, cnt: usize) -> io::Result<&[u8]> {
self.try_peek(cnt)
.await
.transpose()
.ok_or(io::ErrorKind::ConnectionAborted)?
}
pub async fn try_peek(&mut self, cnt: usize) -> io::Result<Option<&[u8]>> {
loop { loop {
// Reaching end-of-file (read 0 bytes) will continuously // Reaching end-of-file (read 0 bytes) will continuously
// return None from all future calls to read // return None from all future calls to read

View File

@ -1,5 +1,5 @@
#![recursion_limit = "256"]
#![forbid(unsafe_code)] #![forbid(unsafe_code)]
#![allow(unused)]
#![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(docsrs, feature(doc_cfg))]
#[macro_use] #[macro_use]
@ -52,7 +52,7 @@ pub use error::{Error, Result};
pub use connection::{Connect, Connection}; pub use connection::{Connect, Connection};
pub use cursor::Cursor; pub use cursor::Cursor;
pub use executor::Executor; pub use executor::{Execute, Executor};
pub use query::{query, Query}; pub use query::{query, Query};
pub use transaction::Transaction; pub use transaction::Transaction;
@ -71,3 +71,7 @@ pub use mysql::MySql;
#[cfg_attr(docsrs, doc(cfg(feature = "postgres")))] #[cfg_attr(docsrs, doc(cfg(feature = "postgres")))]
#[doc(inline)] #[doc(inline)]
pub use postgres::Postgres; pub use postgres::Postgres;
// Named Lifetimes:
// 'c: connection
// 'q: query string (and arguments)

View File

@ -2,7 +2,7 @@ use crate::arguments::Arguments;
use crate::encode::{Encode, IsNull}; use crate::encode::{Encode, IsNull};
use crate::mysql::types::MySqlTypeInfo; use crate::mysql::types::MySqlTypeInfo;
use crate::mysql::MySql; use crate::mysql::MySql;
use crate::types::HasSqlType; use crate::types::Type;
#[derive(Default)] #[derive(Default)]
pub struct MySqlArguments { pub struct MySqlArguments {
@ -27,10 +27,10 @@ impl Arguments for MySqlArguments {
fn add<T>(&mut self, value: T) fn add<T>(&mut self, value: T)
where where
Self::Database: HasSqlType<T>, Self::Database: Type<T>,
T: Encode<Self::Database>, T: Encode<Self::Database>,
{ {
let type_id = <MySql as HasSqlType<T>>::type_info(); let type_id = <MySql as Type<T>>::type_info();
let index = self.param_types.len(); let index = self.param_types.len();
self.param_types.push(type_id); self.param_types.push(type_id);

View File

@ -5,7 +5,7 @@ use crate::decode::Decode;
use crate::mysql::protocol; use crate::mysql::protocol;
use crate::mysql::MySql; use crate::mysql::MySql;
use crate::row::{Row, RowIndex}; use crate::row::{Row, RowIndex};
use crate::types::HasSqlType; use crate::types::Type;
pub struct MySqlRow { pub struct MySqlRow {
pub(super) row: protocol::Row, pub(super) row: protocol::Row,
@ -21,7 +21,7 @@ impl Row for MySqlRow {
fn get<T, I>(&self, index: I) -> T fn get<T, I>(&self, index: I) -> T
where where
Self::Database: HasSqlType<T>, Self::Database: Type<T>,
I: RowIndex<Self>, I: RowIndex<Self>,
T: Decode<Self::Database>, T: Decode<Self::Database>,
{ {
@ -32,7 +32,7 @@ impl Row for MySqlRow {
impl RowIndex<MySqlRow> for usize { impl RowIndex<MySqlRow> for usize {
fn try_get<T>(&self, row: &MySqlRow) -> crate::Result<T> fn try_get<T>(&self, row: &MySqlRow) -> crate::Result<T>
where where
<MySqlRow as Row>::Database: HasSqlType<T>, <MySqlRow as Row>::Database: Type<T>,
T: Decode<<MySqlRow as Row>::Database>, T: Decode<<MySqlRow as Row>::Database>,
{ {
Ok(Decode::decode_nullable(row.row.get(*self))?) Ok(Decode::decode_nullable(row.row.get(*self))?)
@ -42,7 +42,7 @@ impl RowIndex<MySqlRow> for usize {
impl RowIndex<MySqlRow> for &'_ str { impl RowIndex<MySqlRow> for &'_ str {
fn try_get<T>(&self, row: &MySqlRow) -> crate::Result<T> fn try_get<T>(&self, row: &MySqlRow) -> crate::Result<T>
where where
<MySqlRow as Row>::Database: HasSqlType<T>, <MySqlRow as Row>::Database: Type<T>,
T: Decode<<MySqlRow as Row>::Database>, T: Decode<<MySqlRow as Row>::Database>,
{ {
let index = row let index = row

View File

@ -3,9 +3,9 @@ use crate::encode::Encode;
use crate::mysql::protocol::TypeId; use crate::mysql::protocol::TypeId;
use crate::mysql::types::MySqlTypeInfo; use crate::mysql::types::MySqlTypeInfo;
use crate::mysql::MySql; use crate::mysql::MySql;
use crate::types::HasSqlType; use crate::types::Type;
impl HasSqlType<bool> for MySql { impl Type<bool> for MySql {
fn type_info() -> MySqlTypeInfo { fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::TINY_INT) MySqlTypeInfo::new(TypeId::TINY_INT)
} }

View File

@ -6,9 +6,9 @@ use crate::mysql::io::{BufExt, BufMutExt};
use crate::mysql::protocol::TypeId; use crate::mysql::protocol::TypeId;
use crate::mysql::types::MySqlTypeInfo; use crate::mysql::types::MySqlTypeInfo;
use crate::mysql::MySql; use crate::mysql::MySql;
use crate::types::HasSqlType; use crate::types::Type;
impl HasSqlType<[u8]> for MySql { impl Type<[u8]> for MySql {
fn type_info() -> MySqlTypeInfo { fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo { MySqlTypeInfo {
id: TypeId::TEXT, id: TypeId::TEXT,
@ -19,9 +19,9 @@ impl HasSqlType<[u8]> for MySql {
} }
} }
impl HasSqlType<Vec<u8>> for MySql { impl Type<Vec<u8>> for MySql {
fn type_info() -> MySqlTypeInfo { fn type_info() -> MySqlTypeInfo {
<Self as HasSqlType<[u8]>>::type_info() <Self as Type<[u8]>>::type_info()
} }
} }

View File

@ -9,9 +9,9 @@ use crate::io::{Buf, BufMut};
use crate::mysql::protocol::TypeId; use crate::mysql::protocol::TypeId;
use crate::mysql::types::MySqlTypeInfo; use crate::mysql::types::MySqlTypeInfo;
use crate::mysql::MySql; use crate::mysql::MySql;
use crate::types::HasSqlType; use crate::types::Type;
impl HasSqlType<DateTime<Utc>> for MySql { impl Type<DateTime<Utc>> for MySql {
fn type_info() -> MySqlTypeInfo { fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::TIMESTAMP) MySqlTypeInfo::new(TypeId::TIMESTAMP)
} }
@ -31,7 +31,7 @@ impl Decode<MySql> for DateTime<Utc> {
} }
} }
impl HasSqlType<NaiveTime> for MySql { impl Type<NaiveTime> for MySql {
fn type_info() -> MySqlTypeInfo { fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::TIME) MySqlTypeInfo::new(TypeId::TIME)
} }
@ -80,7 +80,7 @@ impl Decode<MySql> for NaiveTime {
} }
} }
impl HasSqlType<NaiveDate> for MySql { impl Type<NaiveDate> for MySql {
fn type_info() -> MySqlTypeInfo { fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::DATE) MySqlTypeInfo::new(TypeId::DATE)
} }
@ -104,7 +104,7 @@ impl Decode<MySql> for NaiveDate {
} }
} }
impl HasSqlType<NaiveDateTime> for MySql { impl Type<NaiveDateTime> for MySql {
fn type_info() -> MySqlTypeInfo { fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::DATETIME) MySqlTypeInfo::new(TypeId::DATETIME)
} }

View File

@ -3,7 +3,7 @@ use crate::encode::Encode;
use crate::mysql::protocol::TypeId; use crate::mysql::protocol::TypeId;
use crate::mysql::types::MySqlTypeInfo; use crate::mysql::types::MySqlTypeInfo;
use crate::mysql::MySql; use crate::mysql::MySql;
use crate::types::HasSqlType; use crate::types::Type;
/// The equivalent MySQL type for `f32` is `FLOAT`. /// The equivalent MySQL type for `f32` is `FLOAT`.
/// ///
@ -18,7 +18,7 @@ use crate::types::HasSqlType;
/// // (This is expected behavior for floating points and happens both in Rust and in MySQL) /// // (This is expected behavior for floating points and happens both in Rust and in MySQL)
/// assert_ne!(10.2f32 as f64, 10.2f64); /// assert_ne!(10.2f32 as f64, 10.2f64);
/// ``` /// ```
impl HasSqlType<f32> for MySql { impl Type<f32> for MySql {
fn type_info() -> MySqlTypeInfo { fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::FLOAT) MySqlTypeInfo::new(TypeId::FLOAT)
} }
@ -40,7 +40,7 @@ impl Decode<MySql> for f32 {
/// ///
/// Note that `DOUBLE` is a floating-point type and cannot represent some fractional values /// Note that `DOUBLE` is a floating-point type and cannot represent some fractional values
/// exactly. /// exactly.
impl HasSqlType<f64> for MySql { impl Type<f64> for MySql {
fn type_info() -> MySqlTypeInfo { fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::DOUBLE) MySqlTypeInfo::new(TypeId::DOUBLE)
} }

View File

@ -6,9 +6,9 @@ use crate::io::{Buf, BufMut};
use crate::mysql::protocol::TypeId; use crate::mysql::protocol::TypeId;
use crate::mysql::types::MySqlTypeInfo; use crate::mysql::types::MySqlTypeInfo;
use crate::mysql::MySql; use crate::mysql::MySql;
use crate::types::HasSqlType; use crate::types::Type;
impl HasSqlType<i8> for MySql { impl Type<i8> for MySql {
fn type_info() -> MySqlTypeInfo { fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::TINY_INT) MySqlTypeInfo::new(TypeId::TINY_INT)
} }
@ -26,7 +26,7 @@ impl Decode<MySql> for i8 {
} }
} }
impl HasSqlType<i16> for MySql { impl Type<i16> for MySql {
fn type_info() -> MySqlTypeInfo { fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::SMALL_INT) MySqlTypeInfo::new(TypeId::SMALL_INT)
} }
@ -44,7 +44,7 @@ impl Decode<MySql> for i16 {
} }
} }
impl HasSqlType<i32> for MySql { impl Type<i32> for MySql {
fn type_info() -> MySqlTypeInfo { fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::INT) MySqlTypeInfo::new(TypeId::INT)
} }
@ -62,7 +62,7 @@ impl Decode<MySql> for i32 {
} }
} }
impl HasSqlType<i64> for MySql { impl Type<i64> for MySql {
fn type_info() -> MySqlTypeInfo { fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::BIG_INT) MySqlTypeInfo::new(TypeId::BIG_INT)
} }

View File

@ -8,9 +8,9 @@ use crate::mysql::io::{BufExt, BufMutExt};
use crate::mysql::protocol::TypeId; use crate::mysql::protocol::TypeId;
use crate::mysql::types::MySqlTypeInfo; use crate::mysql::types::MySqlTypeInfo;
use crate::mysql::MySql; use crate::mysql::MySql;
use crate::types::HasSqlType; use crate::types::Type;
impl HasSqlType<str> for MySql { impl Type<str> for MySql {
fn type_info() -> MySqlTypeInfo { fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo { MySqlTypeInfo {
id: TypeId::TEXT, id: TypeId::TEXT,
@ -28,9 +28,9 @@ impl Encode<MySql> for str {
} }
// TODO: Do we need the [HasSqlType] for String // TODO: Do we need the [HasSqlType] for String
impl HasSqlType<String> for MySql { impl Type<String> for MySql {
fn type_info() -> MySqlTypeInfo { fn type_info() -> MySqlTypeInfo {
<Self as HasSqlType<&str>>::type_info() <Self as Type<&str>>::type_info()
} }
} }

View File

@ -6,9 +6,9 @@ use crate::io::{Buf, BufMut};
use crate::mysql::protocol::TypeId; use crate::mysql::protocol::TypeId;
use crate::mysql::types::MySqlTypeInfo; use crate::mysql::types::MySqlTypeInfo;
use crate::mysql::MySql; use crate::mysql::MySql;
use crate::types::HasSqlType; use crate::types::Type;
impl HasSqlType<u8> for MySql { impl Type<u8> for MySql {
fn type_info() -> MySqlTypeInfo { fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::unsigned(TypeId::TINY_INT) MySqlTypeInfo::unsigned(TypeId::TINY_INT)
} }
@ -26,7 +26,7 @@ impl Decode<MySql> for u8 {
} }
} }
impl HasSqlType<u16> for MySql { impl Type<u16> for MySql {
fn type_info() -> MySqlTypeInfo { fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::unsigned(TypeId::SMALL_INT) MySqlTypeInfo::unsigned(TypeId::SMALL_INT)
} }
@ -44,7 +44,7 @@ impl Decode<MySql> for u16 {
} }
} }
impl HasSqlType<u32> for MySql { impl Type<u32> for MySql {
fn type_info() -> MySqlTypeInfo { fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::unsigned(TypeId::INT) MySqlTypeInfo::unsigned(TypeId::INT)
} }
@ -62,7 +62,7 @@ impl Decode<MySql> for u32 {
} }
} }
impl HasSqlType<u64> for MySql { impl Type<u64> for MySql {
fn type_info() -> MySqlTypeInfo { fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::unsigned(TypeId::BIG_INT) MySqlTypeInfo::unsigned(TypeId::BIG_INT)
} }

View File

@ -1,10 +1,11 @@
use crate::{Connect, Connection}; use crate::{Connect, Connection, Executor};
use futures_core::future::BoxFuture; use futures_core::future::BoxFuture;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use super::inner::{DecrementSizeGuard, SharedPool}; use super::inner::{DecrementSizeGuard, SharedPool};
use crate::describe::Describe;
/// A connection checked out from [`Pool`][crate::Pool]. /// A connection checked out from [`Pool`][crate::Pool].
/// ///
@ -68,6 +69,20 @@ where
live.float(&self.pool).into_idle().close().await live.float(&self.pool).into_idle().close().await
}) })
} }
#[inline]
fn ping(&mut self) -> BoxFuture<crate::Result<()>> {
Box::pin(self.deref_mut().ping())
}
#[doc(hidden)]
#[inline]
fn describe<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
) -> BoxFuture<'e, crate::Result<Describe<Self::Database>>> {
Box::pin(self.deref_mut().describe(query))
}
} }
/// Returns the connection to the [`Pool`][crate::Pool] it was checked-out from. /// Returns the connection to the [`Pool`][crate::Pool] it was checked-out from.
@ -168,8 +183,7 @@ impl<'s, C> Floating<'s, Idle<C>> {
where where
C: Connection, C: Connection,
{ {
// TODO self.live.raw.ping().await self.live.raw.ping().await
todo!()
} }
pub fn into_live(self) -> Floating<'s, Live<C>> { pub fn into_live(self) -> Floating<'s, Live<C>> {

View File

@ -8,84 +8,89 @@ use crate::{
describe::Describe, describe::Describe,
executor::Executor, executor::Executor,
pool::Pool, pool::Pool,
Database, Cursor, Database,
}; };
use super::PoolConnection; use super::PoolConnection;
use crate::database::HasCursor; use crate::database::HasCursor;
use crate::executor::Execute; use crate::executor::Execute;
impl<'p, C> Executor<'p> for &'p Pool<C> impl<'p, C, DB> Executor<'p> for &'p Pool<C>
where where
C: Connect, C: Connect<Database = DB>,
DB: Database<Connection = C>,
DB: for<'c, 'q> HasCursor<'c, 'q>,
for<'con> &'con mut C: Executor<'con>,
{
type Database = DB;
fn execute<'q, E>(self, query: E) -> <Self::Database as HasCursor<'p, 'q>>::Cursor
where
E: Execute<'q, Self::Database>,
{
DB::Cursor::from_pool(self, query)
}
#[inline]
fn execute_by_ref<'q, 'e, E>(
&'e mut self,
query: E,
) -> <Self::Database as HasCursor<'_, 'q>>::Cursor
where
E: Execute<'q, Self::Database>,
{
self.execute(query)
}
}
impl<'c, C, DB> Executor<'c> for &'c mut PoolConnection<C>
where
C: Connect<Database = DB>,
DB: Database<Connection = C>,
DB: for<'c2, 'q> HasCursor<'c2, 'q, Database = DB>,
for<'con> &'con mut C: Executor<'con>, for<'con> &'con mut C: Executor<'con>,
{ {
type Database = C::Database; type Database = C::Database;
fn execute<'q, E>(self, query: E) -> <Self::Database as HasCursor<'p>>::Cursor fn execute<'q, E>(self, query: E) -> <Self::Database as HasCursor<'c, 'q>>::Cursor
where where
E: Execute<'q, Self::Database>, E: Execute<'q, Self::Database>,
{ {
todo!() DB::Cursor::from_connection(&mut **self, query)
} }
#[inline]
fn execute_by_ref<'q, 'e, E>( fn execute_by_ref<'q, 'e, E>(
&'e mut self, &'e mut self,
query: E, query: E,
) -> <Self::Database as HasCursor<'_>>::Cursor ) -> <Self::Database as HasCursor<'_, 'q>>::Cursor
where where
E: Execute<'q, Self::Database>, E: Execute<'q, Self::Database>,
{ {
todo!() self.execute(query)
} }
} }
impl<'c, C> Executor<'c> for &'c mut PoolConnection<C> impl<C, DB> Executor<'static> for PoolConnection<C>
where where
C: Connect, C: Connect<Database = DB>,
for<'con> &'con mut C: Executor<'con>, DB: Database<Connection = C>,
DB: for<'c, 'q> HasCursor<'c, 'q, Database = DB>,
{ {
type Database = C::Database; type Database = DB;
fn execute<'q, E>(self, query: E) -> <Self::Database as HasCursor<'c>>::Cursor fn execute<'q, E>(self, query: E) -> <DB as HasCursor<'static, 'q>>::Cursor
where where
E: Execute<'q, Self::Database>, E: Execute<'q, Self::Database>,
{ {
todo!() DB::Cursor::from_connection(self, query)
} }
fn execute_by_ref<'q, 'e, E>( #[inline]
&'e mut self, fn execute_by_ref<'q, 'e, E>(&'e mut self, query: E) -> <DB as HasCursor<'_, 'q>>::Cursor
query: E,
) -> <Self::Database as HasCursor<'_>>::Cursor
where where
E: Execute<'q, Self::Database>, E: Execute<'q, Self::Database>,
{ {
todo!() DB::Cursor::from_connection(&mut **self, query)
}
}
impl<C> Executor<'static> for PoolConnection<C>
where
C: Connect,
// for<'con> &'con mut C: Executor<'con>,
{
type Database = C::Database;
fn execute<'q, E>(self, query: E) -> <Self::Database as HasCursor<'static>>::Cursor
where
E: Execute<'q, Self::Database>,
{
unimplemented!()
}
fn execute_by_ref<'q, 'e, E>(
&'e mut self,
query: E,
) -> <Self::Database as HasCursor<'_>>::Cursor
where
E: Execute<'q, Self::Database>,
{
todo!()
} }
} }

View File

@ -20,13 +20,15 @@ mod inner;
mod options; mod options;
pub use self::options::Builder; pub use self::options::Builder;
use crate::Database;
/// A pool of database connections. /// A pool of database connections.
pub struct Pool<C>(Arc<SharedPool<C>>); pub struct Pool<C>(Arc<SharedPool<C>>);
impl<C> Pool<C> impl<C, DB> Pool<C>
where where
C: Connect, C: Connect<Database = DB>,
DB: Database<Connection = C>,
{ {
/// Creates a connection pool with the default configuration. /// Creates a connection pool with the default configuration.
/// ///

View File

@ -2,6 +2,7 @@ use std::{marker::PhantomData, time::Duration};
use super::Pool; use super::Pool;
use crate::connection::Connect; use crate::connection::Connect;
use crate::Database;
/// Builder for [Pool]. /// Builder for [Pool].
pub struct Builder<C> { pub struct Builder<C> {
@ -9,7 +10,11 @@ pub struct Builder<C> {
options: Options, options: Options,
} }
impl<C> Builder<C> { impl<C, DB> Builder<C>
where
C: Connect<Database = DB>,
DB: Database<Connection = C>,
{
/// Get a new builder with default options. /// Get a new builder with default options.
/// ///
/// See the source of this method for current defaults. /// See the source of this method for current defaults.
@ -108,7 +113,11 @@ impl<C> Builder<C> {
} }
} }
impl<C> Default for Builder<C> { impl<C, DB> Default for Builder<C>
where
C: Connect<Database = DB>,
DB: Database<Connection = C>,
{
fn default() -> Self { fn default() -> Self {
Self::new() Self::new()
} }

View File

@ -3,7 +3,7 @@ use byteorder::{ByteOrder, NetworkEndian};
use crate::arguments::Arguments; use crate::arguments::Arguments;
use crate::encode::{Encode, IsNull}; use crate::encode::{Encode, IsNull};
use crate::io::BufMut; use crate::io::BufMut;
use crate::types::HasSqlType; use crate::types::Type;
use crate::Postgres; use crate::Postgres;
#[derive(Default)] #[derive(Default)]
@ -25,14 +25,13 @@ impl Arguments for PgArguments {
fn add<T>(&mut self, value: T) fn add<T>(&mut self, value: T)
where where
Self::Database: HasSqlType<T>, T: Type<Self::Database>,
T: Encode<Self::Database>, T: Encode<Self::Database>,
{ {
// TODO: When/if we receive types that do _not_ support BINARY, we need to check here // TODO: When/if we receive types that do _not_ support BINARY, we need to check here
// TODO: There is no need to be explicit unless we are expecting mixed BINARY / TEXT // TODO: There is no need to be explicit unless we are expecting mixed BINARY / TEXT
self.types self.types.push(<T as Type<Postgres>>::type_info().id.0);
.push(<Postgres as HasSqlType<T>>::type_info().id.0);
let pos = self.values.len(); let pos = self.values.len();

View File

@ -1,16 +1,26 @@
use std::convert::TryInto; use std::convert::TryInto;
use std::ops::Range;
use byteorder::NetworkEndian; use byteorder::NetworkEndian;
use futures_core::future::BoxFuture; use futures_core::future::BoxFuture;
use std::net::Shutdown; use futures_core::Future;
use futures_util::TryFutureExt;
use crate::cache::StatementCache; use crate::cache::StatementCache;
use crate::connection::{Connect, Connection}; use crate::connection::{Connect, Connection};
use crate::describe::{Column, Describe};
use crate::io::{Buf, BufStream, MaybeTlsStream}; use crate::io::{Buf, BufStream, MaybeTlsStream};
use crate::postgres::protocol::{self, Authentication, Decode, Encode, Message, StatementId}; use crate::postgres::protocol::{
use crate::postgres::{sasl, PgError}; self, Authentication, AuthenticationMd5, AuthenticationSasl, Decode, Encode, Message,
ParameterDescription, PasswordMessage, RowDescription, StartupMessage, StatementId, Terminate,
};
use crate::postgres::sasl;
use crate::postgres::stream::PgStream;
use crate::postgres::{PgError, PgTypeInfo};
use crate::url::Url; use crate::url::Url;
use crate::{Postgres, Result}; use crate::{Error, Executor, Postgres};
// TODO: TLS
/// An asynchronous connection to a [Postgres][super::Postgres] database. /// An asynchronous connection to a [Postgres][super::Postgres] database.
/// ///
@ -73,301 +83,279 @@ use crate::{Postgres, Result};
/// against the hostname in the server certificate, so they must be the same for the TLS /// against the hostname in the server certificate, so they must be the same for the TLS
/// upgrade to succeed. /// upgrade to succeed.
pub struct PgConnection { pub struct PgConnection {
pub(super) stream: BufStream<MaybeTlsStream>, pub(super) stream: PgStream,
// Map of query to statement id
pub(super) statement_cache: StatementCache<StatementId>,
// Next statement id
pub(super) next_statement_id: u32, pub(super) next_statement_id: u32,
pub(super) is_ready: bool,
// Process ID of the Backend // TODO: Think of a better way to do this, better name perhaps?
process_id: u32, pub(super) data_row_values_buf: Vec<Option<Range<u32>>>,
// Backend-unique key to use to send a cancel query message to the server
secret_key: u32,
// Is there a query in progress; are we ready to continue
pub(super) ready: bool,
} }
impl PgConnection { // https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3
// https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3 async fn startup(stream: &mut PgStream, url: &Url) -> crate::Result<()> {
async fn startup(&mut self, url: &Url) -> Result<()> { // Defaults to postgres@.../postgres
// Defaults to postgres@.../postgres let username = url.username().unwrap_or("postgres");
let username = url.username().unwrap_or("postgres"); let database = url.database().unwrap_or("postgres");
let database = url.database().unwrap_or("postgres");
// See this doc for more runtime parameters // See this doc for more runtime parameters
// https://www.postgresql.org/docs/12/runtime-config-client.html // https://www.postgresql.org/docs/12/runtime-config-client.html
let params = &[ let params = &[
("user", username), ("user", username),
("database", database), ("database", database),
// Sets the display format for date and time values, // Sets the display format for date and time values,
// as well as the rules for interpreting ambiguous date input values. // as well as the rules for interpreting ambiguous date input values.
("DateStyle", "ISO, MDY"), ("DateStyle", "ISO, MDY"),
// Sets the display format for interval values. // Sets the display format for interval values.
("IntervalStyle", "iso_8601"), ("IntervalStyle", "iso_8601"),
// Sets the time zone for displaying and interpreting time stamps. // Sets the time zone for displaying and interpreting time stamps.
("TimeZone", "UTC"), ("TimeZone", "UTC"),
// Adjust postgres to return percise values for floats // Adjust postgres to return percise values for floats
// NOTE: This is default in postgres 12+ // NOTE: This is default in postgres 12+
("extra_float_digits", "3"), ("extra_float_digits", "3"),
// Sets the client-side encoding (character set). // Sets the client-side encoding (character set).
("client_encoding", "UTF-8"), ("client_encoding", "UTF-8"),
]; ];
protocol::StartupMessage { params }.encode(self.stream.buffer_mut()); stream.write(StartupMessage { params });
self.stream.flush().await?; stream.flush().await?;
while let Some(message) = self.receive().await? { loop {
match message { match stream.read().await? {
Message::Authentication(auth) => { Message::Authentication => match Authentication::read(stream.buffer())? {
match *auth { Authentication::Ok => {
protocol::Authentication::Ok => { // do nothing. no password is needed to continue.
// Do nothing. No password is needed to continue. }
}
protocol::Authentication::ClearTextPassword => { Authentication::CleartextPassword => {
protocol::PasswordMessage::ClearText( stream.write(PasswordMessage::ClearText(
&url.password().unwrap_or_default(), &url.password().unwrap_or_default(),
) ));
.encode(self.stream.buffer_mut());
self.stream.flush().await?; stream.flush().await?;
} }
protocol::Authentication::Md5Password { salt } => { Authentication::Md5Password => {
protocol::PasswordMessage::Md5 { // TODO: Just reference the salt instead of returning a stack array
password: &url.password().unwrap_or_default(), // TODO: Better way to make sure we skip the first 4 bytes here
user: username, let data = AuthenticationMd5::read(&stream.buffer()[4..])?;
salt,
}
.encode(self.stream.buffer_mut());
self.stream.flush().await?; stream.write(PasswordMessage::Md5 {
} password: &url.password().unwrap_or_default(),
user: username,
salt: data.salt,
});
protocol::Authentication::Sasl { mechanisms } => { stream.flush().await?;
let mut has_sasl: bool = false; }
let mut has_sasl_plus: bool = false;
for mechanism in &*mechanisms { Authentication::Sasl => {
match &**mechanism { // TODO: Make this iterative for traversing the mechanisms to remove the allocation
"SCRAM-SHA-256" => { // TODO: Better way to make sure we skip the first 4 bytes here
has_sasl = true; let data = AuthenticationSasl::read(&stream.buffer()[4..])?;
}
"SCRAM-SHA-256-PLUS" => { let mut has_sasl: bool = false;
has_sasl_plus = true; let mut has_sasl_plus: bool = false;
}
_ => { for mechanism in &*data.mechanisms {
log::info!("unsupported auth mechanism: {}", mechanism); match &**mechanism {
} "SCRAM-SHA-256" => {
} has_sasl = true;
} }
if has_sasl || has_sasl_plus { "SCRAM-SHA-256-PLUS" => {
// TODO: Handle -PLUS differently if we're in a TLS stream has_sasl_plus = true;
sasl::authenticate( }
self,
username, _ => {
&url.password().unwrap_or_default(), log::info!("unsupported auth mechanism: {}", mechanism);
)
.await?;
} else {
return Err(protocol_err!(
"unsupported SASL auth mechanisms: {:?}",
mechanisms
)
.into());
} }
} }
}
auth => { if has_sasl || has_sasl_plus {
return Err(protocol_err!( // TODO: Handle -PLUS differently if we're in a TLS stream
"requires unimplemented authentication method: {:?}", sasl::authenticate(stream, username, &url.password().unwrap_or_default())
auth .await?;
) } else {
.into()); return Err(protocol_err!(
} "unsupported SASL auth mechanisms: {:?}",
data.mechanisms
)
.into());
} }
} }
Message::BackendKeyData(body) => { auth => {
self.process_id = body.process_id; return Err(
self.secret_key = body.secret_key; protocol_err!("requested unsupported authentication: {:?}", auth).into(),
);
} }
},
Message::ReadyForQuery(_) => { Message::BackendKeyData => {
// Connection fully established and ready to receive queries. // do nothing. we do not care about the server values here.
// todo: we should care and store these on the connection
}
Message::ParameterStatus => {
// do nothing. we do not care about the server values here.
}
Message::ReadyForQuery => {
// done. connection is now fully established and can accept
// queries for execution.
break;
}
type_ => {
return Err(protocol_err!("unexpected message: {:?}", type_).into());
}
}
}
Ok(())
}
// https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.10
async fn terminate(mut stream: PgStream) -> crate::Result<()> {
stream.write(Terminate);
stream.flush().await?;
stream.shutdown()?;
Ok(())
}
impl PgConnection {
pub(super) async fn new(url: crate::Result<Url>) -> crate::Result<Self> {
let url = url?;
let mut stream = PgStream::new(&url).await?;
startup(&mut stream, &url).await?;
Ok(Self {
stream,
data_row_values_buf: Vec::new(),
next_statement_id: 1,
is_ready: true,
})
}
pub(super) async fn wait_until_ready(&mut self) -> crate::Result<()> {
// depending on how the previous query finished we may need to continue
// pulling messages from the stream until we receive a [ReadyForQuery] message
// postgres sends the [ReadyForQuery] message when it's fully complete with processing
// the previous query
if !self.is_ready {
loop {
if let Message::ReadyForQuery = self.stream.read().await? {
// we are now ready to go
self.is_ready = true;
break; break;
} }
message => {
return Err(protocol_err!("received unexpected message: {:?}", message).into());
}
} }
} }
Ok(()) Ok(())
} }
// https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.10 async fn describe<'e, 'q: 'e>(
async fn terminate(mut self) -> Result<()> { &'e mut self,
protocol::Terminate.encode(self.stream.buffer_mut()); query: &'q str,
) -> crate::Result<Describe<Postgres>> {
let statement = self.write_prepare(query, &Default::default());
self.write_describe(protocol::Describe::Statement(statement));
self.write_sync();
self.stream.flush().await?; self.stream.flush().await?;
self.stream.stream.shutdown(Shutdown::Both)?; self.wait_until_ready().await?;
Ok(()) let params = loop {
} match self.stream.read().await? {
Message::ParseComplete => {
// Wait and return the next message to be received from Postgres. // ignore complete messsage
pub(super) async fn receive(&mut self) -> Result<Option<Message>> { // continue
loop {
// Read the message header (id + len)
let mut header = ret_if_none!(self.stream.peek(5).await?);
let id = header.get_u8()?;
let len = (header.get_u32::<NetworkEndian>()? - 4) as usize;
// Read the message body
self.stream.consume(5);
let body = ret_if_none!(self.stream.peek(len).await?);
let message = match id {
b'N' | b'E' => Message::Response(Box::new(protocol::Response::decode(body)?)),
b'D' => Message::DataRow(protocol::DataRow::decode(body)?),
b'S' => {
Message::ParameterStatus(Box::new(protocol::ParameterStatus::decode(body)?))
}
b'Z' => Message::ReadyForQuery(protocol::ReadyForQuery::decode(body)?),
b'R' => Message::Authentication(Box::new(protocol::Authentication::decode(body)?)),
b'K' => Message::BackendKeyData(protocol::BackendKeyData::decode(body)?),
b'C' => Message::CommandComplete(protocol::CommandComplete::decode(body)?),
b'A' => Message::NotificationResponse(Box::new(
protocol::NotificationResponse::decode(body)?,
)),
b'1' => Message::ParseComplete,
b'2' => Message::BindComplete,
b'3' => Message::CloseComplete,
b'n' => Message::NoData,
b's' => Message::PortalSuspended,
b't' => Message::ParameterDescription(Box::new(
protocol::ParameterDescription::decode(body)?,
)),
b'T' => Message::RowDescription(Box::new(protocol::RowDescription::decode(body)?)),
id => {
return Err(protocol_err!("received unknown message id: {:?}", id).into());
}
};
self.stream.consume(len);
match message {
Message::ParameterStatus(_body) => {
// TODO: not sure what to do with these yet
} }
Message::Response(body) => { Message::ParameterDescription => {
if body.severity.is_error() { break ParameterDescription::read(self.stream.buffer())?;
// This is an error, stop the world and bubble as an error
return Err(PgError(body).into());
} else {
// This is a _warning_
// TODO: Log the warning
}
} }
message => { message => {
return Ok(Some(message)); return Err(protocol_err!(
"expected ParameterDescription; received {:?}",
message
)
.into());
} }
} };
}
}
}
impl PgConnection {
pub(super) async fn establish(url: Result<Url>) -> Result<Self> {
let url = url?;
let stream = MaybeTlsStream::connect(&url, 5432).await?;
let mut self_ = Self {
stream: BufStream::new(stream),
process_id: 0,
secret_key: 0,
// Important to start at 1 as 0 means "unnamed" in our protocol
next_statement_id: 1,
statement_cache: StatementCache::new(),
ready: true,
}; };
let ssl_mode = url.get_param("sslmode").unwrap_or("prefer".into()); let result = match self.stream.read().await? {
Message::NoData => None,
Message::RowDescription => Some(RowDescription::read(self.stream.buffer())?),
match &*ssl_mode { message => {
// TODO: on "allow" retry with TLS if startup fails return Err(protocol_err!(
"disable" | "allow" => (), "expected RowDescription or NoData; received {:?}",
message
#[cfg(feature = "tls")]
"prefer" => {
if !self_.try_ssl(&url, true, true).await? {
log::warn!("server does not support TLS, falling back to unsecured connection")
}
}
#[cfg(not(feature = "tls"))]
"prefer" => log::info!("compiled without TLS, skipping upgrade"),
#[cfg(feature = "tls")]
"require" | "verify-ca" | "verify-full" => {
if !self_
.try_ssl(
&url,
ssl_mode == "require", // false for both verify-ca and verify-full
ssl_mode != "verify-full", // false for only verify-full
)
.await?
{
return Err(tls_err!("Postgres server does not support TLS").into());
}
}
#[cfg(not(feature = "tls"))]
"require" | "verify-ca" | "verify-full" => {
return Err(tls_err!(
"sslmode {:?} unsupported; SQLx was compiled without `tls` feature",
ssl_mode
) )
.into()) .into());
} }
_ => return Err(tls_err!("unknown `sslmode` value: {:?}", ssl_mode).into()), };
}
self_.stream.clear_bufs(); Ok(Describe {
param_types: params
self_.startup(&url).await?; .ids
.iter()
Ok(self_) .map(|id| PgTypeInfo::new(*id))
.collect::<Vec<_>>()
.into_boxed_slice(),
result_columns: result
.map(|r| r.fields)
.unwrap_or_default()
.into_vec()
.into_iter()
// TODO: Should [Column] just wrap [protocol::Field] ?
.map(|field| Column {
name: field.name,
table_id: field.table_id,
type_info: PgTypeInfo::new(field.type_id),
})
.collect::<Vec<_>>()
.into_boxed_slice(),
})
} }
} }
impl Connect for PgConnection { impl Connect for PgConnection {
fn connect<T>(url: T) -> BoxFuture<'static, Result<PgConnection>> fn connect<T>(url: T) -> BoxFuture<'static, crate::Result<PgConnection>>
where where
T: TryInto<Url, Error = crate::Error>, T: TryInto<Url, Error = crate::Error>,
Self: Sized, Self: Sized,
{ {
Box::pin(PgConnection::establish(url.try_into())) Box::pin(PgConnection::new(url.try_into()))
} }
} }
impl Connection for PgConnection { impl Connection for PgConnection {
type Database = Postgres; type Database = Postgres;
fn close(self) -> BoxFuture<'static, Result<()>> { fn close(self) -> BoxFuture<'static, crate::Result<()>> {
Box::pin(self.terminate()) Box::pin(terminate(self.stream))
}
fn ping(&mut self) -> BoxFuture<crate::Result<()>> {
Box::pin(self.execute("SELECT 1").map_ok(|_| ()))
}
#[doc(hidden)]
fn describe<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
) -> BoxFuture<'e, crate::Result<Describe<Self::Database>>> {
Box::pin(self.describe(query))
} }
} }

View File

@ -1,56 +1,329 @@
use std::future::Future; use std::future::Future;
use std::mem;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use async_stream::try_stream;
use futures_core::future::BoxFuture; use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream; use futures_core::stream::BoxStream;
use crate::cursor::Cursor; use crate::connection::{ConnectionSource, MaybeOwnedConnection};
use crate::cursor::{Cursor, MapRowFn};
use crate::database::HasRow; use crate::database::HasRow;
use crate::postgres::protocol::StatementId; use crate::executor::Execute;
use crate::postgres::PgConnection; use crate::pool::{Pool, PoolConnection};
use crate::Postgres; use crate::postgres::protocol::{CommandComplete, DataRow, Message, StatementId};
use crate::postgres::{PgArguments, PgConnection, PgRow};
use crate::{Database, Postgres};
pub struct PgCursor<'a> { enum State<'c, 'q> {
statement: StatementId, Query(&'q str, Option<PgArguments>),
connection: &'a mut PgConnection, NextRow,
// Used for `impl Future`
Resolve(BoxFuture<'c, crate::Result<MaybeOwnedConnection<'c, PgConnection>>>),
AffectedRows(BoxFuture<'c, crate::Result<u64>>),
} }
impl<'a> PgCursor<'a> { pub struct PgCursor<'c, 'q> {
pub(super) fn from_connection( source: ConnectionSource<'c, PgConnection>,
connection: &'a mut PgConnection, state: State<'c, 'q>,
statement: StatementId, }
) -> Self {
impl<'c, 'q> Cursor<'c, 'q> for PgCursor<'c, 'q> {
type Database = Postgres;
#[doc(hidden)]
fn from_pool<E>(pool: &Pool<<Self::Database as Database>::Connection>, query: E) -> Self
where
Self: Sized,
E: Execute<'q, Self::Database>,
{
let (query, arguments) = query.into_parts();
Self { Self {
connection, // note: pool is internally reference counted
statement, source: ConnectionSource::Pool(pool.clone()),
state: State::Query(query, arguments),
}
}
#[doc(hidden)]
fn from_connection<E, C>(conn: C, query: E) -> Self
where
Self: Sized,
C: Into<MaybeOwnedConnection<'c, <Self::Database as Database>::Connection>>,
E: Execute<'q, Self::Database>,
{
let (query, arguments) = query.into_parts();
Self {
// note: pool is internally reference counted
source: ConnectionSource::Connection(conn.into()),
state: State::Query(query, arguments),
}
}
fn first(self) -> BoxFuture<'c, crate::Result<Option<<Self::Database as HasRow<'c>>::Row>>>
where
'q: 'c,
{
Box::pin(first(self))
}
fn next(&mut self) -> BoxFuture<crate::Result<Option<<Self::Database as HasRow<'_>>::Row>>> {
Box::pin(next(self))
}
fn map<T, F>(mut self, f: F) -> BoxStream<'c, crate::Result<T>>
where
F: MapRowFn<Self::Database, T>,
T: 'c + Send + Unpin,
'q: 'c,
{
Box::pin(try_stream! {
while let Some(row) = self.next().await? {
yield f.call(row);
}
})
}
}
impl<'s, 'q> Future for PgCursor<'s, 'q> {
type Output = crate::Result<u64>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
match &mut self.state {
State::Query(q, arguments) => {
// todo: existential types can remove both the boxed futures
// and this allocation
let query = q.to_owned();
let arguments = mem::take(arguments);
self.state = State::Resolve(Box::pin(resolve(
mem::take(&mut self.source),
query,
arguments,
)));
}
State::Resolve(fut) => {
match fut.as_mut().poll(cx) {
Poll::Pending => {
return Poll::Pending;
}
Poll::Ready(conn) => {
let conn = conn?;
self.state = State::AffectedRows(Box::pin(affected_rows(conn)));
// continue
}
}
}
State::NextRow => {
panic!("PgCursor must not be polled after being used");
}
State::AffectedRows(fut) => {
return fut.as_mut().poll(cx);
}
}
} }
} }
} }
impl<'a> Cursor<'a> for PgCursor<'a> { // write out query to the connection stream
type Database = Postgres; async fn write(
conn: &mut PgConnection,
query: &str,
arguments: Option<PgArguments>,
) -> crate::Result<()> {
// TODO: Handle [arguments] being None. This should be a SIMPLE query.
let arguments = arguments.unwrap();
fn first(self) -> BoxFuture<'a, crate::Result<Option<<Self::Database as HasRow>::Row>>> { // Check the statement cache for a statement ID that matches the given query
todo!() // If it doesn't exist, we generate a new statement ID and write out [Parse] to the
} // connection command buffer
let statement = conn.write_prepare(query, &arguments);
fn next(&mut self) -> BoxFuture<crate::Result<Option<<Self::Database as HasRow>::Row>>> { // Next, [Bind] attaches the arguments to the statement and creates a named portal
todo!() conn.write_bind("", statement, &arguments);
}
fn map<T, F>(self, f: F) -> BoxStream<'a, crate::Result<T>> // Next, [Describe] will return the expected result columns and types
where // Conditionally run [Describe] only if the results have not been cached
F: Fn(<Self::Database as HasRow>::Row) -> T, // if !self.statement_cache.has_columns(statement) {
{ // self.write_describe(protocol::Describe::Portal(""));
todo!() // }
}
// Next, [Execute] then executes the named portal
conn.write_execute("", 0);
// Finally, [Sync] asks postgres to process the messages that we sent and respond with
// a [ReadyForQuery] message when it's completely done. Theoretically, we could send
// dozens of queries before a [Sync] and postgres can handle that. Execution on the server
// is still serial but it would reduce round-trips. Some kind of builder pattern that is
// termed batching might suit this.
conn.write_sync();
conn.wait_until_ready().await?;
conn.stream.flush().await?;
conn.is_ready = false;
Ok(())
} }
impl<'a> Future for PgCursor<'a> { async fn resolve(
type Output = crate::Result<u64>; mut source: ConnectionSource<'_, PgConnection>,
query: String,
arguments: Option<PgArguments>,
) -> crate::Result<MaybeOwnedConnection<'_, PgConnection>> {
let mut conn = source.resolve_by_ref().await?;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { write(&mut *conn, &query, arguments).await?;
todo!()
} Ok(source.into_connection())
}
async fn affected_rows(mut conn: MaybeOwnedConnection<'_, PgConnection>) -> crate::Result<u64> {
conn.wait_until_ready().await?;
conn.stream.flush().await?;
conn.is_ready = false;
let mut rows = 0;
loop {
match conn.stream.read().await? {
Message::ParseComplete | Message::BindComplete => {
// ignore x_complete messages
}
Message::DataRow => {
// ignore rows
// TODO: should we log or something?
}
Message::CommandComplete => {
rows += CommandComplete::read(conn.stream.buffer())?.affected_rows;
}
Message::ReadyForQuery => {
// done
break;
}
message => {
return Err(protocol_err!("unexpected message: {:?}", message).into());
}
}
}
Ok(rows)
}
async fn next<'a, 'c: 'a, 'q: 'a>(
cursor: &'a mut PgCursor<'c, 'q>,
) -> crate::Result<Option<PgRow<'a>>> {
let mut conn = cursor.source.resolve_by_ref().await?;
match cursor.state {
State::Query(q, ref mut arguments) => {
// write out the query to the connection
write(&mut *conn, q, arguments.take()).await?;
// next time we come through here, skip this block
cursor.state = State::NextRow;
}
State::Resolve(_) | State::AffectedRows(_) => {
panic!("`PgCursor` must not be used after being polled");
}
State::NextRow => {
// grab the next row
}
}
loop {
match conn.stream.read().await? {
Message::ParseComplete | Message::BindComplete => {
// ignore x_complete messages
}
Message::CommandComplete => {
// no more rows
break;
}
Message::DataRow => {
let data = DataRow::read(&mut *conn)?;
return Ok(Some(PgRow {
connection: conn,
columns: Arc::default(),
data,
}));
}
message => {
return Err(protocol_err!("unexpected message: {:?}", message).into());
}
}
}
Ok(None)
}
async fn first<'c, 'q>(mut cursor: PgCursor<'c, 'q>) -> crate::Result<Option<PgRow<'c>>> {
let mut conn = cursor.source.resolve().await?;
match cursor.state {
State::Query(q, ref mut arguments) => {
// write out the query to the connection
write(&mut conn, q, arguments.take()).await?;
}
State::NextRow => {
// just grab the next row as the first
}
State::Resolve(_) | State::AffectedRows(_) => {
panic!("`PgCursor` must not be used after being polled");
}
}
loop {
match conn.stream.read().await? {
Message::ParseComplete | Message::BindComplete => {
// ignore x_complete messages
}
Message::CommandComplete => {
// no more rows
break;
}
Message::DataRow => {
let data = DataRow::read(&mut conn)?;
return Ok(Some(PgRow {
connection: conn,
columns: Arc::default(),
data,
}));
}
message => {
return Err(protocol_err!("unexpected message: {:?}", message).into());
}
}
}
Ok(None)
} }

View File

@ -13,18 +13,18 @@ impl Database for Postgres {
type TableId = u32; type TableId = u32;
} }
impl HasRow for Postgres { impl<'a> HasRow<'a> for Postgres {
// TODO: Can we drop the `type Database = _` // TODO: Can we drop the `type Database = _`
type Database = Postgres; type Database = Postgres;
type Row = super::PgRow; type Row = super::PgRow<'a>;
} }
impl<'a> HasCursor<'a> for Postgres { impl<'s, 'q> HasCursor<'s, 'q> for Postgres {
// TODO: Can we drop the `type Database = _` // TODO: Can we drop the `type Database = _`
type Database = Postgres; type Database = Postgres;
type Cursor = super::PgCursor<'a>; type Cursor = super::PgCursor<'s, 'q>;
} }
impl<'a> HasRawValue<'a> for Postgres { impl<'a> HasRawValue<'a> for Postgres {

View File

@ -1,7 +1,7 @@
use crate::error::DatabaseError; use crate::error::DatabaseError;
use crate::postgres::protocol::Response; use crate::postgres::protocol::Response;
pub struct PgError(pub(super) Box<Response>); pub struct PgError(pub(super) Response);
impl DatabaseError for PgError { impl DatabaseError for PgError {
fn message(&self) -> &str { fn message(&self) -> &str {

View File

@ -2,37 +2,36 @@ use std::collections::HashMap;
use std::io; use std::io;
use std::sync::Arc; use std::sync::Arc;
use crate::cursor::Cursor;
use crate::executor::{Execute, Executor}; use crate::executor::{Execute, Executor};
use crate::postgres::protocol::{self, Encode, Message, StatementId, TypeFormat}; use crate::postgres::protocol::{self, Encode, StatementId, TypeFormat};
use crate::postgres::{PgArguments, PgCursor, PgRow, PgTypeInfo, Postgres}; use crate::postgres::{PgArguments, PgConnection, PgCursor, PgRow, PgTypeInfo, Postgres};
impl super::PgConnection { impl PgConnection {
fn write_prepare(&mut self, query: &str, args: &PgArguments) -> StatementId { pub(crate) fn write_prepare(&mut self, query: &str, args: &PgArguments) -> StatementId {
if let Some(&id) = self.statement_cache.get(query) { // TODO: check query cache
id
} else {
let id = StatementId(self.next_statement_id);
self.next_statement_id += 1;
protocol::Parse { let id = StatementId(self.next_statement_id);
statement: id,
query,
param_types: &*args.types,
}
.encode(self.stream.buffer_mut());
self.statement_cache.put(query.to_owned(), id); self.next_statement_id += 1;
id self.stream.write(protocol::Parse {
} statement: id,
query,
param_types: &*args.types,
});
// TODO: write to query cache
id
} }
fn write_describe(&mut self, d: protocol::Describe) { pub(crate) fn write_describe(&mut self, d: protocol::Describe) {
d.encode(self.stream.buffer_mut()) self.stream.write(d);
} }
fn write_bind(&mut self, portal: &str, statement: StatementId, args: &PgArguments) { pub(crate) fn write_bind(&mut self, portal: &str, statement: StatementId, args: &PgArguments) {
protocol::Bind { self.stream.write(protocol::Bind {
portal, portal,
statement, statement,
formats: &[TypeFormat::Binary], formats: &[TypeFormat::Binary],
@ -40,59 +39,30 @@ impl super::PgConnection {
values_len: args.types.len() as i16, values_len: args.types.len() as i16,
values: &*args.values, values: &*args.values,
result_formats: &[TypeFormat::Binary], result_formats: &[TypeFormat::Binary],
} });
.encode(self.stream.buffer_mut());
} }
fn write_execute(&mut self, portal: &str, limit: i32) { pub(crate) fn write_execute(&mut self, portal: &str, limit: i32) {
protocol::Execute { portal, limit }.encode(self.stream.buffer_mut()); self.stream.write(protocol::Execute { portal, limit });
} }
fn write_sync(&mut self) { pub(crate) fn write_sync(&mut self) {
protocol::Sync.encode(self.stream.buffer_mut()); self.stream.write(protocol::Sync);
} }
} }
impl<'e> Executor<'e> for &'e mut super::PgConnection { impl<'e> Executor<'e> for &'e mut super::PgConnection {
type Database = Postgres; type Database = Postgres;
fn execute<'q, E>(self, query: E) -> PgCursor<'e> fn execute<'q, E>(self, query: E) -> PgCursor<'e, 'q>
where where
E: Execute<'q, Self::Database>, E: Execute<'q, Self::Database>,
{ {
let (query, arguments) = query.into_parts(); PgCursor::from_connection(self, query)
// TODO: Handle [arguments] being None. This should be a SIMPLE query.
let arguments = arguments.unwrap();
// Check the statement cache for a statement ID that matches the given query
// If it doesn't exist, we generate a new statement ID and write out [Parse] to the
// connection command buffer
let statement = self.write_prepare(query, &arguments);
// Next, [Bind] attaches the arguments to the statement and creates a named portal
self.write_bind("", statement, &arguments);
// Next, [Describe] will return the expected result columns and types
// Conditionally run [Describe] only if the results have not been cached
if !self.statement_cache.has_columns(statement) {
self.write_describe(protocol::Describe::Portal(""));
}
// Next, [Execute] then executes the named portal
self.write_execute("", 0);
// Finally, [Sync] asks postgres to process the messages that we sent and respond with
// a [ReadyForQuery] message when it's completely done. Theoretically, we could send
// dozens of queries before a [Sync] and postgres can handle that. Execution on the server
// is still serial but it would reduce round-trips. Some kind of builder pattern that is
// termed batching might suit this.
self.write_sync();
PgCursor::from_connection(self, statement)
} }
fn execute_by_ref<'q, E>(&mut self, query: E) -> PgCursor<'_> #[inline]
fn execute_by_ref<'q, E>(&mut self, query: E) -> PgCursor<'_, 'q>
where where
E: Execute<'q, Self::Database>, E: Execute<'q, Self::Database>,
{ {

View File

@ -17,7 +17,8 @@ mod executor;
mod protocol; mod protocol;
mod row; mod row;
mod sasl; mod sasl;
mod tls; mod stream;
// mod tls;
mod types; mod types;
/// An alias for [`Pool`][crate::Pool], specialized for **Postgres**. /// An alias for [`Pool`][crate::Pool], specialized for **Postgres**.

View File

@ -5,152 +5,176 @@ use std::str;
#[derive(Debug)] #[derive(Debug)]
pub enum Authentication { pub enum Authentication {
/// Authentication was successful. /// The authentication exchange is successfully completed.
Ok, Ok,
/// Kerberos V5 authentication is required. /// The frontend must now take part in a Kerberos V5 authentication dialog (not described
/// here, part of the Kerberos specification) with the server. If this is successful,
/// the server responds with an `AuthenticationOk`, otherwise it responds
/// with an `ErrorResponse`. This is no longer supported.
KerberosV5, KerberosV5,
/// A clear-text password is required. /// The frontend must now send a `PasswordMessage` containing the password in clear-text form.
ClearTextPassword, /// If this is the correct password, the server responds with an `AuthenticationOk`, otherwise it
/// responds with an `ErrorResponse`.
CleartextPassword,
/// An MD5-encrypted password is required. /// The frontend must now send a `PasswordMessage` containing the password (with user name)
Md5Password { salt: [u8; 4] }, /// encrypted via MD5, then encrypted again using the 4-byte random salt specified in the
/// `AuthenticationMD5Password` message. If this is the correct password, the server responds
/// with an `AuthenticationOk`, otherwise it responds with an `ErrorResponse`.
Md5Password,
/// An SCM credentials message is required. /// This response is only possible for local Unix-domain connections on platforms that support
/// SCM credential messages. The frontend must issue an SCM credential message and then
/// send a single data byte.
ScmCredential, ScmCredential,
/// GSSAPI authentication is required. /// The frontend must now initiate a GSSAPI negotiation. The frontend will send a
/// `GSSResponse` message with the first part of the GSSAPI data stream in response to this.
Gss, Gss,
/// SSPI authentication is required. /// The frontend must now initiate a SSPI negotiation.
/// The frontend will send a GSSResponse with the first part of the SSPI data stream in
/// response to this.
Sspi, Sspi,
/// This message contains GSSAPI or SSPI data. /// This message contains the response data from the previous step of GSSAPI
GssContinue { data: Box<[u8]> }, /// or SSPI negotiation.
GssContinue,
/// SASL authentication is required. /// The frontend must now initiate a SASL negotiation, using one of the SASL mechanisms
/// /// listed in the message.
/// The message body is a list of SASL authentication mechanisms, Sasl,
/// in the server's order of preference.
Sasl { mechanisms: Box<[Box<str>]> },
/// This message contains a SASL challenge. /// This message contains challenge data from the previous step of SASL negotiation.
SaslContinue(SaslContinue), SaslContinue,
/// SASL authentication has completed. /// SASL authentication has completed with additional mechanism-specific data for the client.
SaslFinal { data: Box<[u8]> }, SaslFinal,
}
impl Authentication {
pub fn read(mut buf: &[u8]) -> crate::Result<Self> {
Ok(match buf.get_u32::<NetworkEndian>()? {
0 => Authentication::Ok,
2 => Authentication::KerberosV5,
3 => Authentication::CleartextPassword,
5 => Authentication::Md5Password,
6 => Authentication::ScmCredential,
7 => Authentication::Gss,
8 => Authentication::GssContinue,
9 => Authentication::Sspi,
10 => Authentication::Sasl,
11 => Authentication::SaslContinue,
12 => Authentication::SaslFinal,
type_ => {
return Err(protocol_err!("unknown authentication message type: {}", type_).into());
}
})
}
} }
#[derive(Debug)] #[derive(Debug)]
pub struct SaslContinue { pub struct AuthenticationMd5 {
pub salt: [u8; 4],
}
impl AuthenticationMd5 {
pub fn read(mut buf: &[u8]) -> crate::Result<Self> {
let mut salt = [0_u8; 4];
salt.copy_from_slice(buf);
Ok(Self { salt })
}
}
#[derive(Debug)]
pub struct AuthenticationSasl {
pub mechanisms: Box<[Box<str>]>,
}
impl AuthenticationSasl {
pub fn read(mut buf: &[u8]) -> crate::Result<Self> {
let mut mechanisms = Vec::new();
while buf[0] != 0 {
mechanisms.push(buf.get_str_nul()?.into());
}
Ok(Self {
mechanisms: mechanisms.into_boxed_slice(),
})
}
}
#[derive(Debug)]
pub struct AuthenticationSaslContinue {
pub salt: Vec<u8>, pub salt: Vec<u8>,
pub iter_count: u32, pub iter_count: u32,
pub nonce: Vec<u8>, pub nonce: Vec<u8>,
pub data: String, pub data: String,
} }
impl Decode for Authentication { impl AuthenticationSaslContinue {
fn decode(mut buf: &[u8]) -> crate::Result<Self> { pub fn read(mut buf: &[u8]) -> crate::Result<Self> {
Ok(match buf.get_u32::<NetworkEndian>()? { let mut salt: Vec<u8> = Vec::new();
0 => Authentication::Ok, let mut nonce: Vec<u8> = Vec::new();
let mut iter_count: u32 = 0;
2 => Authentication::KerberosV5, let key_value: Vec<(char, &[u8])> = buf
.split(|byte| *byte == b',')
.map(|s| {
let (key, value) = s.split_at(1);
let value = value.split_at(1).1;
3 => Authentication::ClearTextPassword, (key[0] as char, value)
})
.collect();
5 => { for (key, value) in key_value.iter() {
let mut salt = [0_u8; 4]; match key {
salt.copy_from_slice(&buf); 's' => salt = value.to_vec(),
'r' => nonce = value.to_vec(),
Authentication::Md5Password { salt } 'i' => {
} let s = str::from_utf8(&value).map_err(|_| {
protocol_err!(
6 => Authentication::ScmCredential, "iteration count in sasl response was not a valid utf8 string"
)
7 => Authentication::Gss, })?;
iter_count = u32::from_str_radix(&s, 10).unwrap_or(0);
8 => {
let mut data = Vec::with_capacity(buf.len());
data.extend_from_slice(buf);
Authentication::GssContinue {
data: data.into_boxed_slice(),
}
}
9 => Authentication::Sspi,
10 => {
let mut mechanisms = Vec::new();
while buf[0] != 0 {
mechanisms.push(buf.get_str_nul()?.into());
} }
Authentication::Sasl { _ => {}
mechanisms: mechanisms.into_boxed_slice(),
}
} }
}
11 => { Ok(Self {
let mut salt: Vec<u8> = Vec::new(); salt: base64::decode(&salt).map_err(|_| {
let mut nonce: Vec<u8> = Vec::new(); protocol_err!("salt value response from postgres was not base64 encoded")
let mut iter_count: u32 = 0; })?,
nonce,
iter_count,
data: str::from_utf8(buf)
.map_err(|_| protocol_err!("SaslContinue response was not a valid utf8 string"))?
.to_string(),
})
}
}
let key_value: Vec<(char, &[u8])> = buf #[derive(Debug)]
.split(|byte| *byte == b',') pub struct AuthenticationSaslFinal {
.map(|s| { pub data: Box<[u8]>,
let (key, value) = s.split_at(1); }
let value = value.split_at(1).1;
(key[0] as char, value) impl AuthenticationSaslFinal {
}) pub fn read(mut buf: &[u8]) -> crate::Result<Self> {
.collect(); let mut data = Vec::with_capacity(buf.len());
data.extend_from_slice(buf);
for (key, value) in key_value.iter() { Ok(Self {
match key { data: data.into_boxed_slice(),
's' => salt = value.to_vec(),
'r' => nonce = value.to_vec(),
'i' => {
let s = str::from_utf8(&value).map_err(|_| {
protocol_err!(
"iteration count in sasl response was not a valid utf8 string"
)
})?;
iter_count = u32::from_str_radix(&s, 10).unwrap_or(0);
}
_ => {}
}
}
Authentication::SaslContinue(SaslContinue {
salt: base64::decode(&salt).map_err(|_| {
protocol_err!("salt value response from postgres was not base64 encoded")
})?,
nonce,
iter_count,
data: str::from_utf8(buf)
.map_err(|_| {
protocol_err!("SaslContinue response was not a valid utf8 string")
})?
.to_string(),
})
}
12 => {
let mut data = Vec::with_capacity(buf.len());
data.extend_from_slice(buf);
Authentication::SaslFinal {
data: data.into_boxed_slice(),
}
}
id => {
return Err(protocol_err!("unknown authentication response: {}", id).into());
}
}) })
} }
} }
@ -158,27 +182,25 @@ impl Decode for Authentication {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{Authentication, Decode}; use super::{Authentication, Decode};
use crate::postgres::protocol::authentication::AuthenticationMd5;
use matches::assert_matches; use matches::assert_matches;
const AUTH_OK: &[u8] = b"\0\0\0\0"; const AUTH_OK: &[u8] = b"\0\0\0\0";
const AUTH_MD5: &[u8] = b"\0\0\0\x05\x93\x189\x98"; const AUTH_MD5: &[u8] = b"\0\0\0\x05\x93\x189\x98";
#[test] #[test]
fn it_decodes_auth_ok() { fn it_reads_auth_ok() {
let m = Authentication::decode(AUTH_OK).unwrap(); let m = Authentication::read(AUTH_OK).unwrap();
assert_matches!(m, Authentication::Ok); assert_matches!(m, Authentication::Ok);
} }
#[test] #[test]
fn it_decodes_auth_md5_password() { fn it_reads_auth_md5_password() {
let m = Authentication::decode(AUTH_MD5).unwrap(); let m = Authentication::read(AUTH_MD5).unwrap();
let data = AuthenticationMd5::read(&AUTH_MD5[4..]).unwrap();
assert_matches!( assert_matches!(m, Authentication::Md5Password);
m, assert_eq!(data.salt, [147, 24, 57, 152]);
Authentication::Md5Password {
salt: [147, 24, 57, 152]
}
);
} }
} }

View File

@ -6,8 +6,8 @@ pub struct CommandComplete {
pub affected_rows: u64, pub affected_rows: u64,
} }
impl Decode for CommandComplete { impl CommandComplete {
fn decode(mut buf: &[u8]) -> crate::Result<Self> { pub(crate) fn read(mut buf: &[u8]) -> crate::Result<Self> {
// Attempt to parse the last word in the command tag as an integer // Attempt to parse the last word in the command tag as an integer
// If it can't be parsed, the tag is probably "CREATE TABLE" or something // If it can't be parsed, the tag is probably "CREATE TABLE" or something
// and we should return 0 rows // and we should return 0 rows
@ -35,29 +35,29 @@ mod tests {
const COMMAND_COMPLETE_BEGIN: &[u8] = b"BEGIN\0"; const COMMAND_COMPLETE_BEGIN: &[u8] = b"BEGIN\0";
#[test] #[test]
fn it_decodes_command_complete_for_insert() { fn it_reads_command_complete_for_insert() {
let message = CommandComplete::decode(COMMAND_COMPLETE_INSERT).unwrap(); let message = CommandComplete::read(COMMAND_COMPLETE_INSERT).unwrap();
assert_eq!(message.affected_rows, 1); assert_eq!(message.affected_rows, 1);
} }
#[test] #[test]
fn it_decodes_command_complete_for_update() { fn it_reads_command_complete_for_update() {
let message = CommandComplete::decode(COMMAND_COMPLETE_UPDATE).unwrap(); let message = CommandComplete::read(COMMAND_COMPLETE_UPDATE).unwrap();
assert_eq!(message.affected_rows, 512); assert_eq!(message.affected_rows, 512);
} }
#[test] #[test]
fn it_decodes_command_complete_for_begin() { fn it_reads_command_complete_for_begin() {
let message = CommandComplete::decode(COMMAND_COMPLETE_BEGIN).unwrap(); let message = CommandComplete::read(COMMAND_COMPLETE_BEGIN).unwrap();
assert_eq!(message.affected_rows, 0); assert_eq!(message.affected_rows, 0);
} }
#[test] #[test]
fn it_decodes_command_complete_for_create_table() { fn it_reads_command_complete_for_create_table() {
let message = CommandComplete::decode(COMMAND_COMPLETE_CREATE_TABLE).unwrap(); let message = CommandComplete::read(COMMAND_COMPLETE_CREATE_TABLE).unwrap();
assert_eq!(message.affected_rows, 0); assert_eq!(message.affected_rows, 0);
} }

View File

@ -1,34 +1,49 @@
use crate::io::{Buf, ByteStr}; use crate::io::{Buf, ByteStr};
use crate::postgres::protocol::Decode; use crate::postgres::protocol::Decode;
use crate::postgres::PgConnection;
use byteorder::NetworkEndian; use byteorder::NetworkEndian;
use std::fmt::{self, Debug}; use std::fmt::{self, Debug};
use std::ops::Range; use std::ops::Range;
pub struct DataRow { pub struct DataRow {
buffer: Box<[u8]>, len: u16,
values: Box<[Option<Range<u32>>]>,
} }
impl DataRow { impl DataRow {
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
self.values.len() self.len as usize
} }
pub fn get(&self, index: usize) -> Option<&[u8]> { pub fn get<'a>(
let range = self.values[index].as_ref()?; &self,
buffer: &'a [u8],
values: &[Option<Range<u32>>],
index: usize,
) -> Option<&'a [u8]> {
let range = values[index].as_ref()?;
Some(&self.buffer[(range.start as usize)..(range.end as usize)]) Some(&buffer[(range.start as usize)..(range.end as usize)])
} }
} }
impl Decode for DataRow { impl DataRow {
fn decode(mut buf: &[u8]) -> crate::Result<Self> { pub(crate) fn read<'a>(
let len = buf.get_u16::<NetworkEndian>()? as usize; connection: &mut PgConnection,
let buffer: Box<[u8]> = buf.into(); // buffer: &'a [u8],
let mut values = Vec::with_capacity(len); // values: &'a mut Vec<Option<Range<u32>>>,
let mut index = 4; ) -> crate::Result<Self> {
let buffer = connection.stream.buffer();
let values = &mut connection.data_row_values_buf;
while values.len() < len { values.clear();
let mut buf = buffer;
let len = buf.get_u16::<NetworkEndian>()?;
let mut index = 6;
while values.len() < (len as usize) {
// The length of the column value, in bytes (this count does not include itself). // The length of the column value, in bytes (this count does not include itself).
// Can be zero. As a special case, -1 indicates a NULL column value. // Can be zero. As a special case, -1 indicates a NULL column value.
// No value bytes follow in the NULL case. // No value bytes follow in the NULL case.
@ -46,26 +61,7 @@ impl Decode for DataRow {
} }
} }
Ok(Self { Ok(Self { len })
values: values.into_boxed_slice(),
buffer,
})
}
}
impl Debug for DataRow {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "DataRow(")?;
let len = self.values.len();
f.debug_list()
.entries((0..len).map(|i| self.get(i).map(ByteStr)))
.finish()?;
write!(f, ")")?;
Ok(())
} }
} }
@ -76,18 +72,14 @@ mod tests {
const DATA_ROW: &[u8] = b"\0\x03\0\0\0\x011\0\0\0\x012\0\0\0\x013"; const DATA_ROW: &[u8] = b"\0\x03\0\0\0\x011\0\0\0\x012\0\0\0\x013";
#[test] #[test]
fn it_decodes_data_row() { fn it_reads_data_row() {
let m = DataRow::decode(DATA_ROW).unwrap(); let mut values = Vec::new();
let m = DataRow::read(DATA_ROW, &mut values).unwrap();
assert_eq!(m.values.len(), 3); assert_eq!(m.len, 3);
assert_eq!(m.get(0), Some(&b"1"[..])); assert_eq!(m.get(DATA_ROW, &values, 0), Some(&b"1"[..]));
assert_eq!(m.get(1), Some(&b"2"[..])); assert_eq!(m.get(DATA_ROW, &values, 1), Some(&b"2"[..]));
assert_eq!(m.get(2), Some(&b"3"[..])); assert_eq!(m.get(DATA_ROW, &values, 2), Some(&b"3"[..]));
assert_eq!(
format!("{:?}", m),
"DataRow([Some(b\"1\"), Some(b\"2\"), Some(b\"3\")])"
);
} }
} }

View File

@ -1,24 +1,57 @@
use std::convert::TryFrom;
use crate::postgres::protocol::{ use crate::postgres::protocol::{
Authentication, BackendKeyData, CommandComplete, DataRow, NotificationResponse, Authentication, BackendKeyData, CommandComplete, DataRow, NotificationResponse,
ParameterDescription, ParameterStatus, ReadyForQuery, Response, RowDescription, ParameterDescription, ParameterStatus, ReadyForQuery, Response, RowDescription,
}; };
#[derive(Debug)] #[derive(Debug, Copy, Clone)]
#[repr(u8)] #[repr(u8)]
pub enum Message { pub enum Message {
Authentication(Box<Authentication>), Authentication,
ParameterStatus(Box<ParameterStatus>), BackendKeyData,
BackendKeyData(BackendKeyData),
ReadyForQuery(ReadyForQuery),
CommandComplete(CommandComplete),
DataRow(DataRow),
Response(Box<Response>),
NotificationResponse(Box<NotificationResponse>),
ParseComplete,
BindComplete, BindComplete,
CloseComplete, CloseComplete,
CommandComplete,
DataRow,
NoData, NoData,
NotificationResponse,
ParameterDescription,
ParameterStatus,
ParseComplete,
PortalSuspended, PortalSuspended,
ParameterDescription(Box<ParameterDescription>), ReadyForQuery,
RowDescription(Box<RowDescription>), NoticeResponse,
ErrorResponse,
RowDescription,
}
impl TryFrom<u8> for Message {
type Error = crate::Error;
fn try_from(type_: u8) -> crate::Result<Self> {
// https://www.postgresql.org/docs/12/protocol-message-formats.html
Ok(match type_ {
b'E' => Message::ErrorResponse,
b'N' => Message::NoticeResponse,
b'D' => Message::DataRow,
b'S' => Message::ParameterStatus,
b'Z' => Message::ReadyForQuery,
b'R' => Message::Authentication,
b'K' => Message::BackendKeyData,
b'C' => Message::CommandComplete,
b'A' => Message::NotificationResponse,
b'1' => Message::ParseComplete,
b'2' => Message::BindComplete,
b'3' => Message::CloseComplete,
b'n' => Message::NoData,
b's' => Message::PortalSuspended,
b't' => Message::ParameterDescription,
b'T' => Message::RowDescription,
id => {
return Err(protocol_err!("unknown message: {:?}", id).into());
}
})
}
} }

View File

@ -58,7 +58,10 @@ mod row_description;
mod message; mod message;
pub use authentication::Authentication; pub use authentication::{
Authentication, AuthenticationMd5, AuthenticationSasl, AuthenticationSaslContinue,
AuthenticationSaslFinal,
};
pub use backend_key_data::BackendKeyData; pub use backend_key_data::BackendKeyData;
pub use command_complete::CommandComplete; pub use command_complete::CommandComplete;
pub use data_row::DataRow; pub use data_row::DataRow;

View File

@ -7,8 +7,8 @@ pub struct ParameterDescription {
pub ids: Box<[TypeId]>, pub ids: Box<[TypeId]>,
} }
impl Decode for ParameterDescription { impl ParameterDescription {
fn decode(mut buf: &[u8]) -> crate::Result<Self> { pub(crate) fn read(mut buf: &[u8]) -> crate::Result<Self> {
let cnt = buf.get_u16::<NetworkEndian>()? as usize; let cnt = buf.get_u16::<NetworkEndian>()? as usize;
let mut ids = Vec::with_capacity(cnt); let mut ids = Vec::with_capacity(cnt);
@ -27,9 +27,9 @@ mod test {
use super::{Decode, ParameterDescription}; use super::{Decode, ParameterDescription};
#[test] #[test]
fn it_decodes_parameter_description() { fn it_reads_parameter_description() {
let buf = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00"; let buf = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00";
let desc = ParameterDescription::decode(buf).unwrap(); let desc = ParameterDescription::read(buf).unwrap();
assert_eq!(desc.ids.len(), 2); assert_eq!(desc.ids.len(), 2);
assert_eq!(desc.ids[0].0, 0x0000_0000); assert_eq!(desc.ids[0].0, 0x0000_0000);
@ -37,9 +37,9 @@ mod test {
} }
#[test] #[test]
fn it_decodes_empty_parameter_description() { fn it_reads_empty_parameter_description() {
let buf = b"\x00\x00"; let buf = b"\x00\x00";
let desc = ParameterDescription::decode(buf).unwrap(); let desc = ParameterDescription::read(buf).unwrap();
assert_eq!(desc.ids.len(), 0); assert_eq!(desc.ids.len(), 0);
} }

View File

@ -65,8 +65,8 @@ pub struct Response {
pub routine: Option<Box<str>>, pub routine: Option<Box<str>>,
} }
impl Decode for Response { impl Response {
fn decode(mut buf: &[u8]) -> crate::Result<Self> { pub fn read(mut buf: &[u8]) -> crate::Result<Self> {
let mut code = None::<Box<str>>; let mut code = None::<Box<str>>;
let mut message = None::<Box<str>>; let mut message = None::<Box<str>>;
let mut severity = None::<Box<str>>; let mut severity = None::<Box<str>>;

View File

@ -18,8 +18,8 @@ pub struct Field {
pub type_format: TypeFormat, pub type_format: TypeFormat,
} }
impl Decode for RowDescription { impl RowDescription {
fn decode(mut buf: &[u8]) -> crate::Result<Self> { pub(crate) fn read(mut buf: &[u8]) -> crate::Result<Self> {
let cnt = buf.get_u16::<NetworkEndian>()? as usize; let cnt = buf.get_u16::<NetworkEndian>()? as usize;
let mut fields = Vec::with_capacity(cnt); let mut fields = Vec::with_capacity(cnt);
@ -57,7 +57,7 @@ mod test {
use super::{Decode, RowDescription}; use super::{Decode, RowDescription};
#[test] #[test]
fn it_decodes_row_description() { fn it_reads_row_description() {
#[rustfmt::skip] #[rustfmt::skip]
let buf = bytes! { let buf = bytes! {
// Number of Parameters // Number of Parameters
@ -82,7 +82,7 @@ mod test {
0_u8, 0_u8 // format_code 0_u8, 0_u8 // format_code
}; };
let desc = RowDescription::decode(&buf).unwrap(); let desc = RowDescription::read(&buf).unwrap();
assert_eq!(desc.fields.len(), 2); assert_eq!(desc.fields.len(), 2);
assert_eq!(desc.fields[0].type_id.0, 0x0000_0000); assert_eq!(desc.fields[0].type_id.0, 0x0000_0000);
@ -90,9 +90,9 @@ mod test {
} }
#[test] #[test]
fn it_decodes_empty_row_description() { fn it_reads_empty_row_description() {
let buf = b"\x00\x00"; let buf = b"\x00\x00";
let desc = RowDescription::decode(buf).unwrap(); let desc = RowDescription::read(buf).unwrap();
assert_eq!(desc.fields.len(), 0); assert_eq!(desc.fields.len(), 0);
} }

View File

@ -1,58 +1,58 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use crate::connection::MaybeOwnedConnection;
use crate::decode::Decode; use crate::decode::Decode;
use crate::pool::PoolConnection;
use crate::postgres::protocol::DataRow; use crate::postgres::protocol::DataRow;
use crate::postgres::Postgres; use crate::postgres::{PgConnection, Postgres};
use crate::row::{Row, RowIndex}; use crate::row::{Row, RowIndex};
use crate::types::HasSqlType; use crate::types::Type;
pub struct PgRow { pub struct PgRow<'c> {
pub(super) connection: MaybeOwnedConnection<'c, PgConnection>,
pub(super) data: DataRow, pub(super) data: DataRow,
pub(super) columns: Arc<HashMap<Box<str>, usize>>, pub(super) columns: Arc<HashMap<Box<str>, usize>>,
} }
impl Row for PgRow { impl<'c> Row<'c> for PgRow<'c> {
type Database = Postgres; type Database = Postgres;
fn len(&self) -> usize { fn len(&self) -> usize {
self.data.len() self.data.len()
} }
fn get<T, I>(&self, index: I) -> T fn try_get_raw<'i, I>(&'c self, index: I) -> crate::Result<Option<&'c [u8]>>
where where
Self::Database: HasSqlType<T>, I: RowIndex<'c, Self> + 'i,
I: RowIndex<Self>,
T: Decode<Self::Database>,
{ {
index.try_get(self).unwrap() index.try_get_raw(self)
} }
} }
impl RowIndex<PgRow> for usize { impl<'c> RowIndex<'c, PgRow<'c>> for usize {
fn try_get<T>(&self, row: &PgRow) -> crate::Result<T> fn try_get_raw(self, row: &'c PgRow<'c>) -> crate::Result<Option<&'c [u8]>> {
where Ok(row.data.get(
<PgRow as Row>::Database: HasSqlType<T>, row.connection.stream.buffer(),
T: Decode<<PgRow as Row>::Database>, &row.connection.data_row_values_buf,
{ self,
Ok(Decode::decode_nullable(row.data.get(*self))?) ))
} }
} }
impl RowIndex<PgRow> for &'_ str { // impl<'c> RowIndex<'c, PgRow<'c>> for &'_ str {
fn try_get<T>(&self, row: &PgRow) -> crate::Result<T> // fn try_get_raw(self, row: &'r PgRow<'c>) -> crate::Result<Option<&'c [u8]>> {
where // let index = row
<PgRow as Row>::Database: HasSqlType<T>, // .columns
T: Decode<<PgRow as Row>::Database>, // .get(self)
{ // .ok_or_else(|| crate::Error::ColumnNotFound((*self).into()))?;
let index = row //
.columns // Ok(row.data.get(
.get(*self) // row.connection.stream.buffer(),
.ok_or_else(|| crate::Error::ColumnNotFound((*self).into()))?; // &row.connection.data_row_values_buf,
let value = Decode::decode_nullable(row.data.get(*index))?; // *index,
// ))
// }
// }
Ok(value) // TODO: impl_from_row_for_row!(PgRow);
}
}
impl_from_row_for_row!(PgRow);

View File

@ -3,8 +3,10 @@ use rand::Rng;
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
use crate::postgres::protocol::{ use crate::postgres::protocol::{
hi, Authentication, Encode, Message, SaslInitialResponse, SaslResponse, hi, Authentication, AuthenticationSaslContinue, Encode, Message, SaslInitialResponse,
SaslResponse,
}; };
use crate::postgres::stream::PgStream;
use crate::postgres::PgConnection; use crate::postgres::PgConnection;
static GS2_HEADER: &'static str = "n,,"; static GS2_HEADER: &'static str = "n,,";
@ -43,7 +45,7 @@ fn nonce() -> String {
// Performs authenticiton using Simple Authentication Security Layer (SASL) which is what // Performs authenticiton using Simple Authentication Security Layer (SASL) which is what
// Postgres uses // Postgres uses
pub(super) async fn authenticate<T: AsRef<str>>( pub(super) async fn authenticate<T: AsRef<str>>(
conn: &mut PgConnection, stream: &mut PgStream,
username: T, username: T,
password: T, password: T,
) -> crate::Result<()> { ) -> crate::Result<()> {
@ -62,13 +64,18 @@ pub(super) async fn authenticate<T: AsRef<str>>(
client_first_message_bare = client_first_message_bare client_first_message_bare = client_first_message_bare
); );
SaslInitialResponse(&client_first_message).encode(conn.stream.buffer_mut()); stream.write(SaslInitialResponse(&client_first_message));
conn.stream.flush().await?; stream.flush().await?;
let server_first_message = conn.receive().await?; let server_first_message = stream.read().await?;
if let Message::Authentication = server_first_message {
let auth = Authentication::read(stream.buffer())?;
if let Authentication::SaslContinue = auth {
// todo: better way to indicate that we consumed just these 4 bytes?
let sasl = AuthenticationSaslContinue::read(&stream.buffer()[4..])?;
if let Some(Message::Authentication(auth)) = server_first_message {
if let Authentication::SaslContinue(sasl) = *auth {
let server_first_message = sasl.data; let server_first_message = sasl.data;
// SaltedPassword := Hi(Normalize(password), salt, i) // SaltedPassword := Hi(Normalize(password), salt, i)
@ -132,9 +139,11 @@ pub(super) async fn authenticate<T: AsRef<str>>(
client_proof = base64::encode(&client_proof) client_proof = base64::encode(&client_proof)
); );
SaslResponse(&client_final_message).encode(conn.stream.buffer_mut()); stream.write(SaslResponse(&client_final_message));
conn.stream.flush().await?; stream.flush().await?;
let _server_final_response = conn.receive().await?;
let _server_final_response = stream.read().await?;
// todo: assert that this was SaslFinal?
Ok(()) Ok(())
} else { } else {

View File

@ -0,0 +1,90 @@
use std::convert::TryInto;
use std::net::Shutdown;
use byteorder::NetworkEndian;
use crate::io::{Buf, BufStream, MaybeTlsStream};
use crate::postgres::protocol::{Encode, Message, Response};
use crate::postgres::PgError;
use crate::url::Url;
pub struct PgStream {
stream: BufStream<MaybeTlsStream>,
// Most recently received message
// Is referenced by our buffered stream
// Is initialized to ReadyForQuery/0 at the start
message: (Message, u32),
}
impl PgStream {
pub(super) async fn new(url: &Url) -> crate::Result<Self> {
let stream = MaybeTlsStream::connect(&url, 5432).await?;
Ok(Self {
stream: BufStream::new(stream),
message: (Message::ReadyForQuery, 0),
})
}
pub(super) fn shutdown(&self) -> crate::Result<()> {
Ok(self.stream.shutdown(Shutdown::Both)?)
}
#[inline]
pub(super) fn write<M>(&mut self, message: M)
where
M: Encode,
{
message.encode(self.stream.buffer_mut());
}
#[inline]
pub(super) async fn flush(&mut self) -> crate::Result<()> {
Ok(self.stream.flush().await?)
}
pub(super) async fn read(&mut self) -> crate::Result<Message> {
// https://www.postgresql.org/docs/12/protocol-overview.html#PROTOCOL-MESSAGE-CONCEPTS
// All communication is through a stream of messages. The first byte of a message
// identifies the message type, and the next four bytes specify the length of the rest of
// the message (this length count includes itself, but not the message-type byte).
if self.message.1 > 0 {
// If there is any data in our read buffer we need to make sure we flush that
// so reading will return the *next* message
self.stream.consume(self.message.1 as usize);
}
let mut header = self.stream.peek(4 + 1).await?;
let type_ = header.get_u8()?.try_into()?;
let length = header.get_u32::<NetworkEndian>()? - 4;
self.message = (type_, length);
self.stream.consume(4 + 1);
// Wait until there is enough data in the stream. We then return without actually
// inspecting the data. This is then looked at later through the [buffer] function
let _ = self.stream.peek(length as usize).await?;
if let Message::ErrorResponse = type_ {
// This is an error, bubble up as one immediately
return Err(crate::Error::Database(Box::new(PgError(Response::read(
self.stream.buffer(),
)?))));
}
Ok(type_)
}
/// Returns a reference to the internally buffered message.
///
/// This is the body of the message identified by the most recent call
/// to `read`.
#[inline]
pub(super) fn buffer(&self) -> &[u8] {
&self.stream.buffer()[..(self.message.1 as usize)]
}
}

View File

@ -3,15 +3,15 @@ use crate::encode::Encode;
use crate::postgres::protocol::TypeId; use crate::postgres::protocol::TypeId;
use crate::postgres::types::PgTypeInfo; use crate::postgres::types::PgTypeInfo;
use crate::postgres::Postgres; use crate::postgres::Postgres;
use crate::types::HasSqlType; use crate::types::Type;
impl HasSqlType<bool> for Postgres { impl Type<Postgres> for bool {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::BOOL) PgTypeInfo::new(TypeId::BOOL)
} }
} }
impl HasSqlType<[bool]> for Postgres { impl Type<Postgres> for [bool] {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_BOOL) PgTypeInfo::new(TypeId::ARRAY_BOOL)
} }

View File

@ -3,24 +3,24 @@ use crate::encode::Encode;
use crate::postgres::protocol::TypeId; use crate::postgres::protocol::TypeId;
use crate::postgres::types::PgTypeInfo; use crate::postgres::types::PgTypeInfo;
use crate::postgres::Postgres; use crate::postgres::Postgres;
use crate::types::HasSqlType; use crate::types::Type;
impl HasSqlType<[u8]> for Postgres { impl Type<Postgres> for [u8] {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::BYTEA) PgTypeInfo::new(TypeId::BYTEA)
} }
} }
impl HasSqlType<[&'_ [u8]]> for Postgres { impl Type<Postgres> for [&'_ [u8]] {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_BYTEA) PgTypeInfo::new(TypeId::ARRAY_BYTEA)
} }
} }
// TODO: Do we need the [HasSqlType] here on the Vec? // TODO: Do we need the [HasSqlType] here on the Vec?
impl HasSqlType<Vec<u8>> for Postgres { impl Type<Postgres> for Vec<u8> {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
<Self as HasSqlType<[u8]>>::type_info() <[u8] as Type<Postgres>>::type_info()
} }
} }

View File

@ -8,27 +8,27 @@ use crate::encode::Encode;
use crate::postgres::protocol::TypeId; use crate::postgres::protocol::TypeId;
use crate::postgres::types::PgTypeInfo; use crate::postgres::types::PgTypeInfo;
use crate::postgres::Postgres; use crate::postgres::Postgres;
use crate::types::HasSqlType; use crate::types::Type;
impl HasSqlType<NaiveTime> for Postgres { impl Type<Postgres> for NaiveTime {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::TIME) PgTypeInfo::new(TypeId::TIME)
} }
} }
impl HasSqlType<NaiveDate> for Postgres { impl Type<Postgres> for NaiveDate {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::DATE) PgTypeInfo::new(TypeId::DATE)
} }
} }
impl HasSqlType<NaiveDateTime> for Postgres { impl Type<Postgres> for NaiveDateTime {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::TIMESTAMP) PgTypeInfo::new(TypeId::TIMESTAMP)
} }
} }
impl<Tz> HasSqlType<DateTime<Tz>> for Postgres impl<Tz> Type<DateTime<Tz>> for Postgres
where where
Tz: TimeZone, Tz: TimeZone,
{ {
@ -37,25 +37,25 @@ where
} }
} }
impl HasSqlType<[NaiveTime]> for Postgres { impl Type<Postgres> for [NaiveTime] {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_TIME) PgTypeInfo::new(TypeId::ARRAY_TIME)
} }
} }
impl HasSqlType<[NaiveDate]> for Postgres { impl Type<Postgres> for [NaiveDate] {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_DATE) PgTypeInfo::new(TypeId::ARRAY_DATE)
} }
} }
impl HasSqlType<[NaiveDateTime]> for Postgres { impl Type<Postgres> for [NaiveDateTime] {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_TIMESTAMP) PgTypeInfo::new(TypeId::ARRAY_TIMESTAMP)
} }
} }
impl<Tz> HasSqlType<[DateTime<Tz>]> for Postgres impl<Tz> Type<[DateTime<Tz>]> for Postgres
where where
Tz: TimeZone, Tz: TimeZone,
{ {

View File

@ -3,15 +3,15 @@ use crate::encode::Encode;
use crate::postgres::protocol::TypeId; use crate::postgres::protocol::TypeId;
use crate::postgres::types::PgTypeInfo; use crate::postgres::types::PgTypeInfo;
use crate::postgres::Postgres; use crate::postgres::Postgres;
use crate::types::HasSqlType; use crate::types::Type;
impl HasSqlType<f32> for Postgres { impl Type<Postgres> for f32 {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::FLOAT4) PgTypeInfo::new(TypeId::FLOAT4)
} }
} }
impl HasSqlType<[f32]> for Postgres { impl Type<Postgres> for [f32] {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_FLOAT4) PgTypeInfo::new(TypeId::ARRAY_FLOAT4)
} }
@ -31,13 +31,13 @@ impl Decode<Postgres> for f32 {
} }
} }
impl HasSqlType<f64> for Postgres { impl Type<Postgres> for f64 {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::FLOAT8) PgTypeInfo::new(TypeId::FLOAT8)
} }
} }
impl HasSqlType<[f64]> for Postgres { impl Type<Postgres> for [f64] {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_FLOAT8) PgTypeInfo::new(TypeId::ARRAY_FLOAT8)
} }

View File

@ -5,15 +5,15 @@ use crate::encode::Encode;
use crate::postgres::protocol::TypeId; use crate::postgres::protocol::TypeId;
use crate::postgres::types::PgTypeInfo; use crate::postgres::types::PgTypeInfo;
use crate::postgres::Postgres; use crate::postgres::Postgres;
use crate::types::HasSqlType; use crate::types::Type;
impl HasSqlType<i16> for Postgres { impl Type<Postgres> for i16 {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::INT2) PgTypeInfo::new(TypeId::INT2)
} }
} }
impl HasSqlType<[i16]> for Postgres { impl Type<Postgres> for [i16] {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_INT2) PgTypeInfo::new(TypeId::ARRAY_INT2)
} }
@ -31,13 +31,13 @@ impl Decode<Postgres> for i16 {
} }
} }
impl HasSqlType<i32> for Postgres { impl Type<Postgres> for i32 {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::INT4) PgTypeInfo::new(TypeId::INT4)
} }
} }
impl HasSqlType<[i32]> for Postgres { impl Type<Postgres> for [i32] {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_INT4) PgTypeInfo::new(TypeId::ARRAY_INT4)
} }
@ -55,13 +55,13 @@ impl Decode<Postgres> for i32 {
} }
} }
impl HasSqlType<i64> for Postgres { impl Type<Postgres> for i64 {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::INT8) PgTypeInfo::new(TypeId::INT8)
} }
} }
impl HasSqlType<[i64]> for Postgres { impl Type<Postgres> for [i64] {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_INT8) PgTypeInfo::new(TypeId::ARRAY_INT8)
} }

View File

@ -4,25 +4,25 @@ use crate::decode::{Decode, DecodeError};
use crate::encode::Encode; use crate::encode::Encode;
use crate::postgres::protocol::TypeId; use crate::postgres::protocol::TypeId;
use crate::postgres::types::PgTypeInfo; use crate::postgres::types::PgTypeInfo;
use crate::types::HasSqlType; use crate::types::Type;
use crate::Postgres; use crate::Postgres;
impl HasSqlType<str> for Postgres { impl Type<Postgres> for str {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::TEXT) PgTypeInfo::new(TypeId::TEXT)
} }
} }
impl HasSqlType<[&'_ str]> for Postgres { impl Type<Postgres> for [&'_ str] {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_TEXT) PgTypeInfo::new(TypeId::ARRAY_TEXT)
} }
} }
// TODO: Do we need [HasSqlType] on String here? // TODO: Do we need [HasSqlType] on String here?
impl HasSqlType<String> for Postgres { impl Type<Postgres> for String {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
<Self as HasSqlType<str>>::type_info() <str as Type<Postgres>>::type_info()
} }
} }

View File

@ -5,15 +5,15 @@ use crate::encode::Encode;
use crate::postgres::protocol::TypeId; use crate::postgres::protocol::TypeId;
use crate::postgres::types::PgTypeInfo; use crate::postgres::types::PgTypeInfo;
use crate::postgres::Postgres; use crate::postgres::Postgres;
use crate::types::HasSqlType; use crate::types::Type;
impl HasSqlType<Uuid> for Postgres { impl Type<Postgres> for Uuid {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::UUID) PgTypeInfo::new(TypeId::UUID)
} }
} }
impl HasSqlType<[Uuid]> for Postgres { impl Type<Postgres> for [Uuid] {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_UUID) PgTypeInfo::new(TypeId::ARRAY_UUID)
} }

View File

@ -4,68 +4,69 @@ use crate::cursor::Cursor;
use crate::database::{Database, HasCursor, HasRow}; use crate::database::{Database, HasCursor, HasRow};
use crate::encode::Encode; use crate::encode::Encode;
use crate::executor::{Execute, Executor}; use crate::executor::{Execute, Executor};
use crate::types::HasSqlType; use crate::types::Type;
use futures_core::stream::BoxStream; use futures_core::stream::BoxStream;
use futures_util::future::ready; use futures_util::future::ready;
use futures_util::TryFutureExt; use futures_util::TryFutureExt;
use futures_util::TryStreamExt; use futures_util::TryStreamExt;
use std::future::Future; use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::mem;
/// Raw SQL query with bind parameters. Returned by [`query`]. /// Raw SQL query with bind parameters. Returned by [`query`].
pub struct Query<'a, DB, T = <DB as Database>::Arguments> pub struct Query<'q, DB, T = <DB as Database>::Arguments>
where where
DB: Database, DB: Database,
{ {
query: &'a str, query: &'q str,
arguments: T, arguments: T,
database: PhantomData<DB>, database: PhantomData<DB>,
} }
impl<'a, DB, P> Execute<'a, DB> for Query<'a, DB, P> impl<'q, DB, P> Execute<'q, DB> for Query<'q, DB, P>
where where
DB: Database, DB: Database,
P: IntoArguments<DB> + Send, P: IntoArguments<DB> + Send,
{ {
fn into_parts(self) -> (&'a str, Option<<DB as Database>::Arguments>) { fn into_parts(self) -> (&'q str, Option<<DB as Database>::Arguments>) {
(self.query, Some(self.arguments.into_arguments())) (self.query, Some(self.arguments.into_arguments()))
} }
} }
impl<'a, DB, P> Query<'a, DB, P> impl<'q, DB, P> Query<'q, DB, P>
where where
DB: Database, DB: Database,
P: IntoArguments<DB> + Send, P: IntoArguments<DB> + Send,
{ {
pub fn execute<'b, E>(self, executor: E) -> impl Future<Output = crate::Result<u64>> + 'b pub async fn execute<'e, E>(self, executor: E) -> crate::Result<u64>
where where
E: Executor<'b, Database = DB>, E: Executor<'e, Database = DB>,
'a: 'b, {
executor.execute(self).await
}
pub fn fetch<'e, E>(self, executor: E) -> <DB as HasCursor<'e, 'q>>::Cursor
where
E: Executor<'e, Database = DB>,
{ {
executor.execute(self) executor.execute(self)
} }
pub fn fetch<'b, E>(self, executor: E) -> <DB as HasCursor<'b>>::Cursor pub async fn fetch_optional<'e, E>(
where
E: Executor<'b, Database = DB>,
'a: 'b,
{
executor.execute(self)
}
pub async fn fetch_optional<'b, E>(
self, self,
executor: E, executor: E,
) -> crate::Result<Option<<DB as HasRow>::Row>> ) -> crate::Result<Option<<DB as HasRow<'e>>::Row>>
where where
E: Executor<'b, Database = DB>, E: Executor<'e, Database = DB>,
'q: 'e,
{ {
executor.execute(self).first().await executor.execute(self).first().await
} }
pub async fn fetch_one<'b, E>(self, executor: E) -> crate::Result<<DB as HasRow>::Row> pub async fn fetch_one<'e, E>(self, executor: E) -> crate::Result<<DB as HasRow<'e>>::Row>
where where
E: Executor<'b, Database = DB>, E: Executor<'e, Database = DB>,
'q: 'e,
{ {
self.fetch_optional(executor) self.fetch_optional(executor)
.and_then(|row| match row { .and_then(|row| match row {
@ -83,7 +84,7 @@ where
/// Bind a value for use with this SQL query. /// Bind a value for use with this SQL query.
pub fn bind<T>(mut self, value: T) -> Self pub fn bind<T>(mut self, value: T) -> Self
where where
DB: HasSqlType<T>, T: Type<DB>,
T: Encode<DB>, T: Encode<DB>,
{ {
self.arguments.add(value); self.arguments.add(value);

View File

@ -2,20 +2,17 @@
use crate::database::Database; use crate::database::Database;
use crate::decode::Decode; use crate::decode::Decode;
use crate::types::HasSqlType; use crate::types::Type;
pub trait RowIndex<R: ?Sized> pub trait RowIndex<'c, R: ?Sized>
where where
R: Row, R: Row<'c>,
{ {
fn try_get<T>(&self, row: &R) -> crate::Result<T> fn try_get_raw(self, row: &'c R) -> crate::Result<Option<&'c [u8]>>;
where
R::Database: HasSqlType<T>,
T: Decode<R::Database>;
} }
/// Represents a single row of the result set. /// Represents a single row of the result set.
pub trait Row: Unpin + Send + 'static { pub trait Row<'c>: Unpin + Send {
type Database: Database + ?Sized; type Database: Database + ?Sized;
/// Returns `true` if the row contains no values. /// Returns `true` if the row contains no values.
@ -26,18 +23,34 @@ pub trait Row: Unpin + Send + 'static {
/// Returns the number of values in the row. /// Returns the number of values in the row.
fn len(&self) -> usize; fn len(&self) -> usize;
/// Returns the value at the `index`; can either be an integer ordinal or a column name. fn get<T, I>(&'c self, index: I) -> T
fn get<T, I>(&self, index: I) -> T
where where
Self::Database: HasSqlType<T>, T: Type<Self::Database>,
I: RowIndex<Self>, I: RowIndex<'c, Self>,
T: Decode<Self::Database>; T: Decode<Self::Database>,
{
// todo: use expect with a proper message
self.try_get(index).unwrap()
}
fn try_get<T, I>(&'c self, index: I) -> crate::Result<T>
where
T: Type<Self::Database>,
I: RowIndex<'c, Self>,
T: Decode<Self::Database>,
{
Ok(Decode::decode_nullable(self.try_get_raw(index)?)?)
}
fn try_get_raw<'i, I>(&'c self, index: I) -> crate::Result<Option<&'c [u8]>>
where
I: RowIndex<'c, Self> + 'i;
} }
/// A **record** that can be built from a row returned from by the database. /// A **record** that can be built from a row returned from by the database.
pub trait FromRow<R> pub trait FromRow<'a, R>
where where
R: Row, R: Row<'a>,
{ {
fn from_row(row: R) -> Self; fn from_row(row: R) -> Self;
} }

View File

@ -4,6 +4,7 @@ use futures_core::future::BoxFuture;
use crate::connection::Connection; use crate::connection::Connection;
use crate::database::HasCursor; use crate::database::HasCursor;
use crate::describe::Describe;
use crate::executor::{Execute, Executor}; use crate::executor::{Execute, Executor};
use crate::runtime::spawn; use crate::runtime::spawn;
use crate::Database; use crate::Database;
@ -19,10 +20,11 @@ where
depth: u32, depth: u32,
} }
impl<T> Transaction<T> impl<DB, T> Transaction<T>
where where
T: Connection, T: Connection<Database = DB>,
T: Executor<'static>, DB: Database,
T: Executor<'static, Database = DB>,
{ {
pub(crate) async fn new(depth: u32, mut inner: T) -> crate::Result<Self> { pub(crate) async fn new(depth: u32, mut inner: T) -> crate::Result<Self> {
if depth == 0 { if depth == 0 {
@ -98,10 +100,11 @@ where
} }
} }
impl<T> Connection for Transaction<T> impl<T, DB> Connection for Transaction<T>
where where
T: Connection, T: Connection<Database = DB>,
T: Executor<'static>, DB: Database,
T: Executor<'static, Database = DB>,
{ {
type Database = <T as Connection>::Database; type Database = <T as Connection>::Database;
@ -109,9 +112,23 @@ where
fn close(self) -> BoxFuture<'static, crate::Result<()>> { fn close(self) -> BoxFuture<'static, crate::Result<()>> {
Box::pin(async move { self.rollback().await?.close().await }) Box::pin(async move { self.rollback().await?.close().await })
} }
#[inline]
fn ping(&mut self) -> BoxFuture<crate::Result<()>> {
Box::pin(self.deref_mut().ping())
}
#[doc(hidden)]
#[inline]
fn describe<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
) -> BoxFuture<'e, crate::Result<Describe<Self::Database>>> {
Box::pin(self.deref_mut().describe(query))
}
} }
impl<'a, DB, T> Executor<'a> for &'a mut Transaction<T> impl<'c, DB, T> Executor<'c> for &'c mut Transaction<T>
where where
DB: Database, DB: Database,
T: Connection<Database = DB>, T: Connection<Database = DB>,
@ -119,19 +136,19 @@ where
{ {
type Database = <T as Connection>::Database; type Database = <T as Connection>::Database;
fn execute<'b, E>(self, query: E) -> <<T as Connection>::Database as HasCursor<'a>>::Cursor fn execute<'q, E>(self, query: E) -> <<T as Connection>::Database as HasCursor<'c, 'q>>::Cursor
where where
E: Execute<'b, Self::Database>, E: Execute<'q, Self::Database>,
{ {
(**self).execute_by_ref(query) (**self).execute_by_ref(query)
} }
fn execute_by_ref<'b, 'c, E>( fn execute_by_ref<'q, 'e, E>(
&'c mut self, &'e mut self,
query: E, query: E,
) -> <Self::Database as HasCursor<'c>>::Cursor ) -> <Self::Database as HasCursor<'e, 'q>>::Cursor
where where
E: Execute<'b, Self::Database>, E: Execute<'q, Self::Database>,
{ {
(**self).execute_by_ref(query) (**self).execute_by_ref(query)
} }

View File

@ -21,29 +21,34 @@ pub trait TypeInfo: Debug + Display + Clone {
} }
/// Indicates that a SQL type is supported for a database. /// Indicates that a SQL type is supported for a database.
pub trait HasSqlType<T: ?Sized>: Database { pub trait Type<DB>
where
DB: Database,
{
/// Returns the canonical type information on the database for the type `T`. /// Returns the canonical type information on the database for the type `T`.
fn type_info() -> Self::TypeInfo; fn type_info() -> DB::TypeInfo;
} }
// For references to types in Rust, the underlying SQL type information // For references to types in Rust, the underlying SQL type information
// is equivalent // is equivalent
impl<T: ?Sized, DB> HasSqlType<&'_ T> for DB impl<T: ?Sized, DB> Type<DB> for &'_ T
where where
DB: HasSqlType<T>, DB: Database,
T: Type<DB>,
{ {
fn type_info() -> Self::TypeInfo { fn type_info() -> DB::TypeInfo {
<DB as HasSqlType<T>>::type_info() <T as Type<DB>>::type_info()
} }
} }
// For optional types in Rust, the underlying SQL type information // For optional types in Rust, the underlying SQL type information
// is equivalent // is equivalent
impl<T, DB> HasSqlType<Option<T>> for DB impl<T, DB> Type<DB> for Option<T>
where where
DB: HasSqlType<T>, DB: Database,
T: Type<DB>,
{ {
fn type_info() -> Self::TypeInfo { fn type_info() -> DB::TypeInfo {
<DB as HasSqlType<T>>::type_info() <T as Type<DB>>::type_info()
} }
} }

View File

@ -32,7 +32,7 @@ macro_rules! impl_database_ext {
$( $(
// `if` statements cannot have attributes but these can // `if` statements cannot have attributes but these can
$(#[$meta])? $(#[$meta])?
_ if sqlx::types::TypeInfo::compatible(&<$database as sqlx::types::HasSqlType<$ty>>::type_info(), &info) => Some(input_ty!($ty $(, $input)?)), _ if sqlx::types::TypeInfo::compatible(&<$ty as sqlx::types::Type<$database>>::type_info(), &info) => Some(input_ty!($ty $(, $input)?)),
)* )*
_ => None _ => None
} }
@ -42,7 +42,7 @@ macro_rules! impl_database_ext {
match () { match () {
$( $(
$(#[$meta])? $(#[$meta])?
_ if sqlx::types::TypeInfo::compatible(&<$database as sqlx::types::HasSqlType<$ty>>::type_info(), &info) => return Some(stringify!($ty)), _ if sqlx::types::TypeInfo::compatible(&<$ty as sqlx::types::Type<$database>>::type_info(), &info) => return Some(stringify!($ty)),
)* )*
_ => None _ => None
} }

View File

@ -12,7 +12,7 @@ pub use sqlx_core::{arguments, describe, error, pool, row, types};
// Types // Types
pub use sqlx_core::{ pub use sqlx_core::{
Connect, Connection, Database, Error, Executor, FromRow, Pool, Query, QueryAs, Result, Row, Connect, Connection, Cursor, Database, Error, Executor, FromRow, Pool, Query, QueryAs, Result, Row,
Transaction, Transaction,
}; };

View File

@ -1,59 +1,59 @@
use sqlx::{postgres::PgConnection, Connection as _, Row}; // use sqlx::{postgres::PgConnection, Connect as _, Connection as _, Row};
//
async fn connect() -> anyhow::Result<PgConnection> { // async fn connect() -> anyhow::Result<PgConnection> {
Ok(PgConnection::open(dotenv::var("DATABASE_URL")?).await?) // Ok(PgConnection::connect(dotenv::var("DATABASE_URL")?).await?)
} // }
//
macro_rules! test { // macro_rules! test {
($name:ident: $ty:ty: $($text:literal == $value:expr),+) => { // ($name:ident: $ty:ty: $($text:literal == $value:expr),+) => {
#[cfg_attr(feature = "runtime-async-std", async_std::test)] // #[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)] // #[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn $name () -> anyhow::Result<()> { // async fn $name () -> anyhow::Result<()> {
let mut conn = connect().await?; // let mut conn = connect().await?;
//
$( // $(
let row = sqlx::query(&format!("SELECT {} = $1, $1 as _1", $text)) // let row = sqlx::query(&format!("SELECT {} = $1, $1 as _1", $text))
.bind($value) // .bind($value)
.fetch_one(&mut conn) // .fetch_one(&mut conn)
.await?; // .await?;
//
assert!(row.get::<bool, _>(0)); // assert!(row.get::<bool, _>(0));
assert!($value == row.get::<$ty, _>("_1")); // assert!($value == row.get::<$ty, _>("_1"));
)+ // )+
//
Ok(()) // Ok(())
} // }
} // }
} // }
//
test!(postgres_bool: bool: "false::boolean" == false, "true::boolean" == true); // test!(postgres_bool: bool: "false::boolean" == false, "true::boolean" == true);
//
test!(postgres_smallint: i16: "821::smallint" == 821_i16); // test!(postgres_smallint: i16: "821::smallint" == 821_i16);
test!(postgres_int: i32: "94101::int" == 94101_i32); // test!(postgres_int: i32: "94101::int" == 94101_i32);
test!(postgres_bigint: i64: "9358295312::bigint" == 9358295312_i64); // test!(postgres_bigint: i64: "9358295312::bigint" == 9358295312_i64);
//
test!(postgres_real: f32: "9419.122::real" == 9419.122_f32); // test!(postgres_real: f32: "9419.122::real" == 9419.122_f32);
test!(postgres_double: f64: "939399419.1225182::double precision" == 939399419.1225182_f64); // test!(postgres_double: f64: "939399419.1225182::double precision" == 939399419.1225182_f64);
//
test!(postgres_text: String: "'this is foo'" == "this is foo", "''" == ""); // test!(postgres_text: String: "'this is foo'" == "this is foo", "''" == "");
//
#[cfg_attr(feature = "runtime-async-std", async_std::test)] // #[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)] // #[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn postgres_bytes() -> anyhow::Result<()> { // async fn postgres_bytes() -> anyhow::Result<()> {
let mut conn = connect().await?; // let mut conn = connect().await?;
//
let value = b"Hello, World"; // let value = b"Hello, World";
//
let row = sqlx::query("SELECT E'\\\\x48656c6c6f2c20576f726c64' = $1, $1") // let row = sqlx::query("SELECT E'\\\\x48656c6c6f2c20576f726c64' = $1, $1")
.bind(&value[..]) // .bind(&value[..])
.fetch_one(&mut conn) // .fetch_one(&mut conn)
.await?; // .await?;
//
assert!(row.get::<bool, _>(0)); // assert!(row.get::<bool, _>(0));
//
let output: Vec<u8> = row.get(1); // let output: Vec<u8> = row.get(1);
//
assert_eq!(&value[..], &*output); // assert_eq!(&value[..], &*output);
//
Ok(()) // Ok(())
} // }

View File

@ -1,5 +1,5 @@
use futures::TryStreamExt; use futures::TryStreamExt;
use sqlx::{postgres::PgConnection, Connection as _, Executor as _, Row as _}; use sqlx::{postgres::PgConnection, Connect, Connection, Executor, Row};
use sqlx_core::postgres::PgPool; use sqlx_core::postgres::PgPool;
use std::time::Duration; use std::time::Duration;
@ -17,58 +17,40 @@ async fn it_connects() -> anyhow::Result<()> {
Ok(()) Ok(())
} }
#[cfg_attr(feature = "runtime-async-std", async_std::test)] // #[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)] // #[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn it_executes() -> anyhow::Result<()> { // async fn it_executes() -> anyhow::Result<()> {
let mut conn = connect().await?; // let mut conn = connect().await?;
//
let _ = conn // let _ = conn
.send( // .send(
r#" // r#"
CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY); // CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY);
"#, // "#,
) // )
.await?; // .await?;
//
for index in 1..=10_i32 { // for index in 1..=10_i32 {
let cnt = sqlx::query("INSERT INTO users (id) VALUES ($1)") // let cnt = sqlx::query("INSERT INTO users (id) VALUES ($1)")
.bind(index) // .bind(index)
.execute(&mut conn) // .execute(&mut conn)
.await?; // .await?;
//
assert_eq!(cnt, 1); // assert_eq!(cnt, 1);
} // }
//
let sum: i32 = sqlx::query("SELECT id FROM users") // let sum: i32 = sqlx::query("SELECT id FROM users")
.fetch(&mut conn) // .fetch(&mut conn)
.try_fold( // .try_fold(
0_i32, // 0_i32,
|acc, x| async move { Ok(acc + x.get::<i32, _>("id")) }, // |acc, x| async move { Ok(acc + x.get::<i32, _>("id")) },
) // )
.await?; // .await?;
//
assert_eq!(sum, 55); // assert_eq!(sum, 55);
//
Ok(()) // Ok(())
} // }
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn it_remains_stable_issue_30() -> anyhow::Result<()> {
let mut conn = connect().await?;
// This tests the internal buffer wrapping around at the end
// Specifically: https://github.com/launchbadge/sqlx/issues/30
let rows = sqlx::query("SELECT i, random()::text FROM generate_series(1, 1000) as i")
.fetch_all(&mut conn)
.await?;
assert_eq!(rows.len(), 1000);
assert_eq!(rows[rows.len() - 1].get::<i32, _>(0), 1000);
Ok(())
}
// https://github.com/launchbadge/sqlx/issues/104 // https://github.com/launchbadge/sqlx/issues/104
#[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-async-std", async_std::test)]
@ -122,7 +104,7 @@ async fn pool_smoke_test() -> anyhow::Result<()> {
let pool = pool.clone(); let pool = pool.clone();
spawn(async move { spawn(async move {
loop { loop {
if let Err(e) = sqlx::query("select 1 + 1").fetch_one(&mut &pool).await { if let Err(e) = sqlx::query("select 1 + 1").fetch_one(&pool).await {
eprintln!("pool task {} dying due to {}", i, e); eprintln!("pool task {} dying due to {}", i, e);
break; break;
} }
@ -159,5 +141,5 @@ async fn pool_smoke_test() -> anyhow::Result<()> {
async fn connect() -> anyhow::Result<PgConnection> { async fn connect() -> anyhow::Result<PgConnection> {
let _ = dotenv::dotenv(); let _ = dotenv::dotenv();
let _ = env_logger::try_init(); let _ = env_logger::try_init();
Ok(PgConnection::open(dotenv::var("DATABASE_URL")?).await?) Ok(PgConnection::connect(dotenv::var("DATABASE_URL")?).await?)
} }