postgres: rewrite protocol in more iterative and lazy fashion

This commit is contained in:
Ryan Leckey 2020-02-19 08:10:27 -08:00
parent 3795d15e1c
commit a374c18a18
60 changed files with 1586 additions and 931 deletions

9
Cargo.lock generated
View File

@ -1526,6 +1526,15 @@ dependencies = [
"uuid 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "sqlx-example-postgres-basic"
version = "0.1.0"
dependencies = [
"anyhow 1.0.26 (registry+https://github.com/rust-lang/crates.io-index)",
"async-std 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
"sqlx 0.2.5",
]
[[package]]
name = "sqlx-example-realworld-postgres"
version = "0.1.0"

View File

@ -3,6 +3,7 @@ members = [
".",
"sqlx-core",
"sqlx-macros",
"examples/postgres/basic",
"examples/realworld-postgres",
"examples/todos-postgres",
]

View File

@ -0,0 +1,10 @@
[package]
workspace = "../../.."
name = "sqlx-example-postgres-basic"
version = "0.1.0"
edition = "2018"
[dependencies]
async-std = { version = "1", features = [ "attributes" ] }
anyhow = "1"
sqlx = { path = "../../..", features = [ "postgres" ] }

View File

@ -0,0 +1,25 @@
use sqlx::{Connect, Connection, Cursor, Executor, PgConnection, Row};
use std::convert::TryInto;
use std::time::Instant;
#[async_std::main]
async fn main() -> anyhow::Result<()> {
let mut conn = PgConnection::connect("postgres://").await?;
let mut rows = sqlx::query("SELECT definition FROM pg_database")
.execute(&mut conn)
.await?;
// let start = Instant::now();
// while let Some(row) = cursor.next().await? {
// // let raw = row.try_get(0)?.unwrap();
//
// // println!("hai: {:?}", raw);
// }
println!("?? = {}", rows);
// conn.close().await?;
Ok(())
}

View File

@ -1,2 +0,0 @@
select * from (select (1) as id, 'Herp Derpinson' as name) accounts
where id = ?

View File

@ -2,7 +2,7 @@
use crate::database::Database;
use crate::encode::Encode;
use crate::types::HasSqlType;
use crate::types::Type;
/// A tuple of arguments to be sent to the database.
pub trait Arguments: Send + Sized + Default + 'static {
@ -15,7 +15,7 @@ pub trait Arguments: Send + Sized + Default + 'static {
/// Add the value to the end of the arguments.
fn add<T>(&mut self, value: T)
where
Self::Database: HasSqlType<T>,
T: Type<Self::Database>,
T: Encode<Self::Database>;
}

View File

@ -1,10 +1,14 @@
use std::convert::TryInto;
use std::ops::{Deref, DerefMut};
use futures_core::future::BoxFuture;
use futures_util::TryFutureExt;
use crate::database::Database;
use crate::describe::Describe;
use crate::executor::Executor;
use crate::pool::{Pool, PoolConnection};
use crate::url::Url;
use futures_core::future::BoxFuture;
use futures_util::TryFutureExt;
use std::convert::TryInto;
/// Represents a single database connection rather than a pool of database connections.
///
@ -20,20 +24,13 @@ where
fn close(self) -> BoxFuture<'static, crate::Result<()>>;
/// Verifies a connection to the database is still alive.
fn ping(&mut self) -> BoxFuture<crate::Result<()>>
where
for<'a> &'a mut Self: Executor<'a>,
{
Box::pin((&mut *self).execute("SELECT 1").map_ok(|_| ()))
}
fn ping(&mut self) -> BoxFuture<crate::Result<()>>;
#[doc(hidden)]
fn describe<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
) -> BoxFuture<'e, crate::Result<Describe<Self::Database>>> {
todo!("make this a required function");
}
) -> BoxFuture<'e, crate::Result<Describe<Self::Database>>>;
}
/// Represents a type that can directly establish a new connection.
@ -44,3 +41,125 @@ pub trait Connect: Connection {
T: TryInto<Url, Error = crate::Error>,
Self: Sized;
}
mod internal {
pub enum MaybeOwnedConnection<'c, C>
where
C: super::Connect,
{
Borrowed(&'c mut C),
Owned(super::PoolConnection<C>),
}
pub enum ConnectionSource<'c, C>
where
C: super::Connect,
{
Empty,
Connection(MaybeOwnedConnection<'c, C>),
Pool(super::Pool<C>),
}
}
pub(crate) use self::internal::{ConnectionSource, MaybeOwnedConnection};
impl<'c, C> MaybeOwnedConnection<'c, C>
where
C: Connect,
{
pub(crate) fn borrow(&mut self) -> MaybeOwnedConnection<'_, C> {
match self {
MaybeOwnedConnection::Borrowed(conn) => MaybeOwnedConnection::Borrowed(&mut *conn),
MaybeOwnedConnection::Owned(ref mut conn) => MaybeOwnedConnection::Borrowed(conn),
}
}
}
impl<'c, C, DB> ConnectionSource<'c, C>
where
C: Connect<Database = DB>,
DB: Database<Connection = C>,
{
pub(crate) async fn resolve_by_ref(&mut self) -> crate::Result<MaybeOwnedConnection<'_, C>> {
if let ConnectionSource::Pool(pool) = self {
*self =
ConnectionSource::Connection(MaybeOwnedConnection::Owned(pool.acquire().await?));
}
Ok(match self {
ConnectionSource::Empty => panic!("`PgCursor` must not be used after being polled"),
ConnectionSource::Connection(conn) => conn.borrow(),
ConnectionSource::Pool(_) => unreachable!(),
})
}
pub(crate) async fn resolve(mut self) -> crate::Result<MaybeOwnedConnection<'c, C>> {
if let ConnectionSource::Pool(pool) = self {
self = ConnectionSource::Connection(MaybeOwnedConnection::Owned(pool.acquire().await?));
}
Ok(self.into_connection())
}
pub(crate) fn into_connection(self) -> MaybeOwnedConnection<'c, C> {
match self {
ConnectionSource::Connection(conn) => conn,
ConnectionSource::Empty | ConnectionSource::Pool(_) => {
panic!("`PgCursor` must not be used after being polled");
}
}
}
}
impl<C> Default for ConnectionSource<'_, C>
where
C: Connect,
{
fn default() -> Self {
ConnectionSource::Empty
}
}
impl<'c, C> From<&'c mut C> for MaybeOwnedConnection<'c, C>
where
C: Connect,
{
fn from(conn: &'c mut C) -> Self {
MaybeOwnedConnection::Borrowed(conn)
}
}
impl<'c, C> From<PoolConnection<C>> for MaybeOwnedConnection<'c, C>
where
C: Connect,
{
fn from(conn: PoolConnection<C>) -> Self {
MaybeOwnedConnection::Owned(conn)
}
}
impl<'c, C> Deref for MaybeOwnedConnection<'c, C>
where
C: Connect,
{
type Target = C;
fn deref(&self) -> &Self::Target {
match self {
MaybeOwnedConnection::Borrowed(conn) => conn,
MaybeOwnedConnection::Owned(conn) => conn,
}
}
}
impl<'c, C> DerefMut for MaybeOwnedConnection<'c, C>
where
C: Connect,
{
fn deref_mut(&mut self) -> &mut Self::Target {
match self {
MaybeOwnedConnection::Borrowed(conn) => conn,
MaybeOwnedConnection::Owned(conn) => conn,
}
}
}

View File

@ -3,7 +3,10 @@ use std::future::Future;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use crate::connection::MaybeOwnedConnection;
use crate::database::{Database, HasRow};
use crate::executor::Execute;
use crate::{Connect, Pool};
/// Represents a result set, which is generated by executing a query against the database.
///
@ -13,7 +16,7 @@ use crate::database::{Database, HasRow};
/// Initially the `Cursor` is positioned before the first row. The `next` method moves the cursor
/// to the next row, and because it returns `None` when there are no more rows, it can be used
/// in a `while` loop to iterate through all returned rows.
pub trait Cursor<'a>
pub trait Cursor<'c, 'q>
where
Self: Send,
// `.await`-ing a cursor will return the affected rows from the query
@ -21,16 +24,59 @@ where
{
type Database: Database;
/// Fetch the first row in the result. Returns `None` if no row is present.
///
/// Returns `Error::MoreThanOneRow` if more than one row is in the result.
fn first(self) -> BoxFuture<'a, crate::Result<Option<<Self::Database as HasRow>::Row>>>;
// Construct the [Cursor] from a [Pool]
// Meant for internal use only
// TODO: Anyone have any better ideas on how to instantiate cursors generically from a pool?
#[doc(hidden)]
fn from_pool<E>(pool: &Pool<<Self::Database as Database>::Connection>, query: E) -> Self
where
Self: Sized,
E: Execute<'q, Self::Database>;
#[doc(hidden)]
fn from_connection<E, C>(conn: C, query: E) -> Self
where
Self: Sized,
<Self::Database as Database>::Connection: Connect,
// MaybeOwnedConnection<'c, <Self::Database as Database>::Connection>:
// Connect<Database = Self::Database>,
C: Into<MaybeOwnedConnection<'c, <Self::Database as Database>::Connection>>,
E: Execute<'q, Self::Database>;
#[doc(hidden)]
fn first(self) -> BoxFuture<'c, crate::Result<Option<<Self::Database as HasRow<'c>>::Row>>>
where
'q: 'c;
/// Fetch the next row in the result. Returns `None` if there are no more rows.
fn next(&mut self) -> BoxFuture<crate::Result<Option<<Self::Database as HasRow>::Row>>>;
/// Map the `Row`s in this result to a different type, returning a [`Stream`] of the results.
fn map<T, F>(self, f: F) -> BoxStream<'a, crate::Result<T>>
fn map<T, F>(self, f: F) -> BoxStream<'c, crate::Result<T>>
where
F: Fn(<Self::Database as HasRow>::Row) -> T;
F: MapRowFn<Self::Database, T>,
T: 'c + Send + Unpin,
'q: 'c;
}
pub trait MapRowFn<DB, T>
where
Self: Send + Sync + 'static,
DB: Database,
DB: for<'c> HasRow<'c>,
{
fn call(&self, row: <DB as HasRow>::Row) -> T;
}
impl<DB, T, F> MapRowFn<DB, T> for F
where
DB: Database,
DB: for<'c> HasRow<'c>,
F: Send + Sync + 'static,
F: Fn(<DB as HasRow>::Row) -> T,
{
#[inline(always)]
fn call(&self, row: <DB as HasRow>::Row) -> T {
self(row)
}
}

View File

@ -13,9 +13,9 @@ use crate::types::TypeInfo;
pub trait Database
where
Self: Sized + 'static,
Self: HasRow<Database = Self>,
Self: for<'a> HasRow<'a, Database = Self>,
Self: for<'a> HasRawValue<'a>,
Self: for<'a> HasCursor<'a, Database = Self>,
Self: for<'c, 'q> HasCursor<'c, 'q, Database = Self>,
{
/// The concrete `Connection` implementation for this database.
type Connection: Connection<Database = Self>;
@ -34,14 +34,14 @@ pub trait HasRawValue<'a> {
type RawValue;
}
pub trait HasCursor<'a> {
pub trait HasCursor<'c, 'q> {
type Database: Database;
type Cursor: Cursor<'a, Database = Self::Database>;
type Cursor: Cursor<'c, 'q, Database = Self::Database>;
}
pub trait HasRow {
pub trait HasRow<'a> {
type Database: Database;
type Row: Row<Database = Self::Database>;
type Row: Row<'a, Database = Self::Database>;
}

View File

@ -4,7 +4,7 @@ use std::error::Error as StdError;
use std::fmt::{self, Display};
use crate::database::Database;
use crate::types::HasSqlType;
use crate::types::Type;
pub enum DecodeError {
/// An unexpected `NULL` was encountered while decoding.
@ -40,7 +40,8 @@ where
impl<T, DB> Decode<DB> for Option<T>
where
DB: Database + HasSqlType<T>,
DB: Database,
T: Type<DB>,
T: Decode<DB>,
{
fn decode(buf: &[u8]) -> Result<Self, DecodeError> {

View File

@ -1,7 +1,7 @@
//! Types and traits for encoding values to the database.
use crate::database::Database;
use crate::types::HasSqlType;
use crate::types::Type;
use std::mem;
/// The return type of [Encode::encode].
@ -36,7 +36,8 @@ where
impl<T: ?Sized, DB> Encode<DB> for &'_ T
where
DB: Database + HasSqlType<T>,
DB: Database,
T: Type<DB>,
T: Encode<DB>,
{
fn encode(&self, buf: &mut Vec<u8>) {
@ -54,7 +55,8 @@ where
impl<T, DB> Encode<DB> for Option<T>
where
DB: Database + HasSqlType<T>,
DB: Database,
T: Type<DB>,
T: Encode<DB>,
{
fn encode(&self, buf: &mut Vec<u8>) {

View File

@ -14,7 +14,7 @@ use futures_util::TryStreamExt;
/// Implementations are provided for [`&Pool`](struct.Pool.html),
/// [`&mut PoolConnection`](struct.PoolConnection.html),
/// and [`&mut Connection`](trait.Connection.html).
pub trait Executor<'a>
pub trait Executor<'c>
where
Self: Send,
{
@ -22,18 +22,18 @@ where
type Database: Database;
/// Executes a query that may or may not return a result set.
fn execute<'b, E>(self, query: E) -> <Self::Database as HasCursor<'a>>::Cursor
fn execute<'q, E>(self, query: E) -> <Self::Database as HasCursor<'c, 'q>>::Cursor
where
E: Execute<'b, Self::Database>;
E: Execute<'q, Self::Database>;
#[doc(hidden)]
fn execute_by_ref<'b, E>(&mut self, query: E) -> <Self::Database as HasCursor<'_>>::Cursor
fn execute_by_ref<'b, E>(&mut self, query: E) -> <Self::Database as HasCursor<'_, 'b>>::Cursor
where
E: Execute<'b, Self::Database>;
}
/// A type that may be executed against a database connection.
pub trait Execute<'a, DB>
pub trait Execute<'q, DB>
where
DB: Database,
{
@ -43,15 +43,15 @@ where
/// prepare the query. Returning `Some(Default::default())` is an empty arguments object that
/// will be prepared (and cached) before execution.
#[doc(hidden)]
fn into_parts(self) -> (&'a str, Option<DB::Arguments>);
fn into_parts(self) -> (&'q str, Option<DB::Arguments>);
}
impl<'a, DB> Execute<'a, DB> for &'a str
impl<'q, DB> Execute<'q, DB> for &'q str
where
DB: Database,
{
#[inline]
fn into_parts(self) -> (&'a str, Option<DB::Arguments>) {
fn into_parts(self) -> (&'q str, Option<DB::Arguments>) {
(self, None)
}
}

View File

@ -35,6 +35,11 @@ where
}
}
#[inline]
pub fn buffer(&self) -> &[u8] {
&self.rbuf[self.rbuf_rindex..]
}
#[inline]
pub fn buffer_mut(&mut self) -> &mut Vec<u8> {
&mut self.wbuf
@ -61,7 +66,14 @@ where
self.rbuf_rindex += cnt;
}
pub async fn peek(&mut self, cnt: usize) -> io::Result<Option<&[u8]>> {
pub async fn peek(&mut self, cnt: usize) -> io::Result<&[u8]> {
self.try_peek(cnt)
.await
.transpose()
.ok_or(io::ErrorKind::ConnectionAborted)?
}
pub async fn try_peek(&mut self, cnt: usize) -> io::Result<Option<&[u8]>> {
loop {
// Reaching end-of-file (read 0 bytes) will continuously
// return None from all future calls to read

View File

@ -1,5 +1,5 @@
#![recursion_limit = "256"]
#![forbid(unsafe_code)]
#![allow(unused)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#[macro_use]
@ -52,7 +52,7 @@ pub use error::{Error, Result};
pub use connection::{Connect, Connection};
pub use cursor::Cursor;
pub use executor::Executor;
pub use executor::{Execute, Executor};
pub use query::{query, Query};
pub use transaction::Transaction;
@ -71,3 +71,7 @@ pub use mysql::MySql;
#[cfg_attr(docsrs, doc(cfg(feature = "postgres")))]
#[doc(inline)]
pub use postgres::Postgres;
// Named Lifetimes:
// 'c: connection
// 'q: query string (and arguments)

View File

@ -2,7 +2,7 @@ use crate::arguments::Arguments;
use crate::encode::{Encode, IsNull};
use crate::mysql::types::MySqlTypeInfo;
use crate::mysql::MySql;
use crate::types::HasSqlType;
use crate::types::Type;
#[derive(Default)]
pub struct MySqlArguments {
@ -27,10 +27,10 @@ impl Arguments for MySqlArguments {
fn add<T>(&mut self, value: T)
where
Self::Database: HasSqlType<T>,
Self::Database: Type<T>,
T: Encode<Self::Database>,
{
let type_id = <MySql as HasSqlType<T>>::type_info();
let type_id = <MySql as Type<T>>::type_info();
let index = self.param_types.len();
self.param_types.push(type_id);

View File

@ -5,7 +5,7 @@ use crate::decode::Decode;
use crate::mysql::protocol;
use crate::mysql::MySql;
use crate::row::{Row, RowIndex};
use crate::types::HasSqlType;
use crate::types::Type;
pub struct MySqlRow {
pub(super) row: protocol::Row,
@ -21,7 +21,7 @@ impl Row for MySqlRow {
fn get<T, I>(&self, index: I) -> T
where
Self::Database: HasSqlType<T>,
Self::Database: Type<T>,
I: RowIndex<Self>,
T: Decode<Self::Database>,
{
@ -32,7 +32,7 @@ impl Row for MySqlRow {
impl RowIndex<MySqlRow> for usize {
fn try_get<T>(&self, row: &MySqlRow) -> crate::Result<T>
where
<MySqlRow as Row>::Database: HasSqlType<T>,
<MySqlRow as Row>::Database: Type<T>,
T: Decode<<MySqlRow as Row>::Database>,
{
Ok(Decode::decode_nullable(row.row.get(*self))?)
@ -42,7 +42,7 @@ impl RowIndex<MySqlRow> for usize {
impl RowIndex<MySqlRow> for &'_ str {
fn try_get<T>(&self, row: &MySqlRow) -> crate::Result<T>
where
<MySqlRow as Row>::Database: HasSqlType<T>,
<MySqlRow as Row>::Database: Type<T>,
T: Decode<<MySqlRow as Row>::Database>,
{
let index = row

View File

@ -3,9 +3,9 @@ use crate::encode::Encode;
use crate::mysql::protocol::TypeId;
use crate::mysql::types::MySqlTypeInfo;
use crate::mysql::MySql;
use crate::types::HasSqlType;
use crate::types::Type;
impl HasSqlType<bool> for MySql {
impl Type<bool> for MySql {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::TINY_INT)
}

View File

@ -6,9 +6,9 @@ use crate::mysql::io::{BufExt, BufMutExt};
use crate::mysql::protocol::TypeId;
use crate::mysql::types::MySqlTypeInfo;
use crate::mysql::MySql;
use crate::types::HasSqlType;
use crate::types::Type;
impl HasSqlType<[u8]> for MySql {
impl Type<[u8]> for MySql {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo {
id: TypeId::TEXT,
@ -19,9 +19,9 @@ impl HasSqlType<[u8]> for MySql {
}
}
impl HasSqlType<Vec<u8>> for MySql {
impl Type<Vec<u8>> for MySql {
fn type_info() -> MySqlTypeInfo {
<Self as HasSqlType<[u8]>>::type_info()
<Self as Type<[u8]>>::type_info()
}
}

View File

@ -9,9 +9,9 @@ use crate::io::{Buf, BufMut};
use crate::mysql::protocol::TypeId;
use crate::mysql::types::MySqlTypeInfo;
use crate::mysql::MySql;
use crate::types::HasSqlType;
use crate::types::Type;
impl HasSqlType<DateTime<Utc>> for MySql {
impl Type<DateTime<Utc>> for MySql {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::TIMESTAMP)
}
@ -31,7 +31,7 @@ impl Decode<MySql> for DateTime<Utc> {
}
}
impl HasSqlType<NaiveTime> for MySql {
impl Type<NaiveTime> for MySql {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::TIME)
}
@ -80,7 +80,7 @@ impl Decode<MySql> for NaiveTime {
}
}
impl HasSqlType<NaiveDate> for MySql {
impl Type<NaiveDate> for MySql {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::DATE)
}
@ -104,7 +104,7 @@ impl Decode<MySql> for NaiveDate {
}
}
impl HasSqlType<NaiveDateTime> for MySql {
impl Type<NaiveDateTime> for MySql {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::DATETIME)
}

View File

@ -3,7 +3,7 @@ use crate::encode::Encode;
use crate::mysql::protocol::TypeId;
use crate::mysql::types::MySqlTypeInfo;
use crate::mysql::MySql;
use crate::types::HasSqlType;
use crate::types::Type;
/// The equivalent MySQL type for `f32` is `FLOAT`.
///
@ -18,7 +18,7 @@ use crate::types::HasSqlType;
/// // (This is expected behavior for floating points and happens both in Rust and in MySQL)
/// assert_ne!(10.2f32 as f64, 10.2f64);
/// ```
impl HasSqlType<f32> for MySql {
impl Type<f32> for MySql {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::FLOAT)
}
@ -40,7 +40,7 @@ impl Decode<MySql> for f32 {
///
/// Note that `DOUBLE` is a floating-point type and cannot represent some fractional values
/// exactly.
impl HasSqlType<f64> for MySql {
impl Type<f64> for MySql {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::DOUBLE)
}

View File

@ -6,9 +6,9 @@ use crate::io::{Buf, BufMut};
use crate::mysql::protocol::TypeId;
use crate::mysql::types::MySqlTypeInfo;
use crate::mysql::MySql;
use crate::types::HasSqlType;
use crate::types::Type;
impl HasSqlType<i8> for MySql {
impl Type<i8> for MySql {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::TINY_INT)
}
@ -26,7 +26,7 @@ impl Decode<MySql> for i8 {
}
}
impl HasSqlType<i16> for MySql {
impl Type<i16> for MySql {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::SMALL_INT)
}
@ -44,7 +44,7 @@ impl Decode<MySql> for i16 {
}
}
impl HasSqlType<i32> for MySql {
impl Type<i32> for MySql {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::INT)
}
@ -62,7 +62,7 @@ impl Decode<MySql> for i32 {
}
}
impl HasSqlType<i64> for MySql {
impl Type<i64> for MySql {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::BIG_INT)
}

View File

@ -8,9 +8,9 @@ use crate::mysql::io::{BufExt, BufMutExt};
use crate::mysql::protocol::TypeId;
use crate::mysql::types::MySqlTypeInfo;
use crate::mysql::MySql;
use crate::types::HasSqlType;
use crate::types::Type;
impl HasSqlType<str> for MySql {
impl Type<str> for MySql {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo {
id: TypeId::TEXT,
@ -28,9 +28,9 @@ impl Encode<MySql> for str {
}
// TODO: Do we need the [HasSqlType] for String
impl HasSqlType<String> for MySql {
impl Type<String> for MySql {
fn type_info() -> MySqlTypeInfo {
<Self as HasSqlType<&str>>::type_info()
<Self as Type<&str>>::type_info()
}
}

View File

@ -6,9 +6,9 @@ use crate::io::{Buf, BufMut};
use crate::mysql::protocol::TypeId;
use crate::mysql::types::MySqlTypeInfo;
use crate::mysql::MySql;
use crate::types::HasSqlType;
use crate::types::Type;
impl HasSqlType<u8> for MySql {
impl Type<u8> for MySql {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::unsigned(TypeId::TINY_INT)
}
@ -26,7 +26,7 @@ impl Decode<MySql> for u8 {
}
}
impl HasSqlType<u16> for MySql {
impl Type<u16> for MySql {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::unsigned(TypeId::SMALL_INT)
}
@ -44,7 +44,7 @@ impl Decode<MySql> for u16 {
}
}
impl HasSqlType<u32> for MySql {
impl Type<u32> for MySql {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::unsigned(TypeId::INT)
}
@ -62,7 +62,7 @@ impl Decode<MySql> for u32 {
}
}
impl HasSqlType<u64> for MySql {
impl Type<u64> for MySql {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::unsigned(TypeId::BIG_INT)
}

View File

@ -1,10 +1,11 @@
use crate::{Connect, Connection};
use crate::{Connect, Connection, Executor};
use futures_core::future::BoxFuture;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::time::Instant;
use super::inner::{DecrementSizeGuard, SharedPool};
use crate::describe::Describe;
/// A connection checked out from [`Pool`][crate::Pool].
///
@ -68,6 +69,20 @@ where
live.float(&self.pool).into_idle().close().await
})
}
#[inline]
fn ping(&mut self) -> BoxFuture<crate::Result<()>> {
Box::pin(self.deref_mut().ping())
}
#[doc(hidden)]
#[inline]
fn describe<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
) -> BoxFuture<'e, crate::Result<Describe<Self::Database>>> {
Box::pin(self.deref_mut().describe(query))
}
}
/// Returns the connection to the [`Pool`][crate::Pool] it was checked-out from.
@ -168,8 +183,7 @@ impl<'s, C> Floating<'s, Idle<C>> {
where
C: Connection,
{
// TODO self.live.raw.ping().await
todo!()
self.live.raw.ping().await
}
pub fn into_live(self) -> Floating<'s, Live<C>> {

View File

@ -8,84 +8,89 @@ use crate::{
describe::Describe,
executor::Executor,
pool::Pool,
Database,
Cursor, Database,
};
use super::PoolConnection;
use crate::database::HasCursor;
use crate::executor::Execute;
impl<'p, C> Executor<'p> for &'p Pool<C>
impl<'p, C, DB> Executor<'p> for &'p Pool<C>
where
C: Connect,
C: Connect<Database = DB>,
DB: Database<Connection = C>,
DB: for<'c, 'q> HasCursor<'c, 'q>,
for<'con> &'con mut C: Executor<'con>,
{
type Database = DB;
fn execute<'q, E>(self, query: E) -> <Self::Database as HasCursor<'p, 'q>>::Cursor
where
E: Execute<'q, Self::Database>,
{
DB::Cursor::from_pool(self, query)
}
#[inline]
fn execute_by_ref<'q, 'e, E>(
&'e mut self,
query: E,
) -> <Self::Database as HasCursor<'_, 'q>>::Cursor
where
E: Execute<'q, Self::Database>,
{
self.execute(query)
}
}
impl<'c, C, DB> Executor<'c> for &'c mut PoolConnection<C>
where
C: Connect<Database = DB>,
DB: Database<Connection = C>,
DB: for<'c2, 'q> HasCursor<'c2, 'q, Database = DB>,
for<'con> &'con mut C: Executor<'con>,
{
type Database = C::Database;
fn execute<'q, E>(self, query: E) -> <Self::Database as HasCursor<'p>>::Cursor
fn execute<'q, E>(self, query: E) -> <Self::Database as HasCursor<'c, 'q>>::Cursor
where
E: Execute<'q, Self::Database>,
{
todo!()
DB::Cursor::from_connection(&mut **self, query)
}
#[inline]
fn execute_by_ref<'q, 'e, E>(
&'e mut self,
query: E,
) -> <Self::Database as HasCursor<'_>>::Cursor
) -> <Self::Database as HasCursor<'_, 'q>>::Cursor
where
E: Execute<'q, Self::Database>,
{
todo!()
self.execute(query)
}
}
impl<'c, C> Executor<'c> for &'c mut PoolConnection<C>
impl<C, DB> Executor<'static> for PoolConnection<C>
where
C: Connect,
for<'con> &'con mut C: Executor<'con>,
C: Connect<Database = DB>,
DB: Database<Connection = C>,
DB: for<'c, 'q> HasCursor<'c, 'q, Database = DB>,
{
type Database = C::Database;
type Database = DB;
fn execute<'q, E>(self, query: E) -> <Self::Database as HasCursor<'c>>::Cursor
fn execute<'q, E>(self, query: E) -> <DB as HasCursor<'static, 'q>>::Cursor
where
E: Execute<'q, Self::Database>,
{
todo!()
DB::Cursor::from_connection(self, query)
}
fn execute_by_ref<'q, 'e, E>(
&'e mut self,
query: E,
) -> <Self::Database as HasCursor<'_>>::Cursor
#[inline]
fn execute_by_ref<'q, 'e, E>(&'e mut self, query: E) -> <DB as HasCursor<'_, 'q>>::Cursor
where
E: Execute<'q, Self::Database>,
{
todo!()
}
}
impl<C> Executor<'static> for PoolConnection<C>
where
C: Connect,
// for<'con> &'con mut C: Executor<'con>,
{
type Database = C::Database;
fn execute<'q, E>(self, query: E) -> <Self::Database as HasCursor<'static>>::Cursor
where
E: Execute<'q, Self::Database>,
{
unimplemented!()
}
fn execute_by_ref<'q, 'e, E>(
&'e mut self,
query: E,
) -> <Self::Database as HasCursor<'_>>::Cursor
where
E: Execute<'q, Self::Database>,
{
todo!()
DB::Cursor::from_connection(&mut **self, query)
}
}

View File

@ -20,13 +20,15 @@ mod inner;
mod options;
pub use self::options::Builder;
use crate::Database;
/// A pool of database connections.
pub struct Pool<C>(Arc<SharedPool<C>>);
impl<C> Pool<C>
impl<C, DB> Pool<C>
where
C: Connect,
C: Connect<Database = DB>,
DB: Database<Connection = C>,
{
/// Creates a connection pool with the default configuration.
///

View File

@ -2,6 +2,7 @@ use std::{marker::PhantomData, time::Duration};
use super::Pool;
use crate::connection::Connect;
use crate::Database;
/// Builder for [Pool].
pub struct Builder<C> {
@ -9,7 +10,11 @@ pub struct Builder<C> {
options: Options,
}
impl<C> Builder<C> {
impl<C, DB> Builder<C>
where
C: Connect<Database = DB>,
DB: Database<Connection = C>,
{
/// Get a new builder with default options.
///
/// See the source of this method for current defaults.
@ -108,7 +113,11 @@ impl<C> Builder<C> {
}
}
impl<C> Default for Builder<C> {
impl<C, DB> Default for Builder<C>
where
C: Connect<Database = DB>,
DB: Database<Connection = C>,
{
fn default() -> Self {
Self::new()
}

View File

@ -3,7 +3,7 @@ use byteorder::{ByteOrder, NetworkEndian};
use crate::arguments::Arguments;
use crate::encode::{Encode, IsNull};
use crate::io::BufMut;
use crate::types::HasSqlType;
use crate::types::Type;
use crate::Postgres;
#[derive(Default)]
@ -25,14 +25,13 @@ impl Arguments for PgArguments {
fn add<T>(&mut self, value: T)
where
Self::Database: HasSqlType<T>,
T: Type<Self::Database>,
T: Encode<Self::Database>,
{
// TODO: When/if we receive types that do _not_ support BINARY, we need to check here
// TODO: There is no need to be explicit unless we are expecting mixed BINARY / TEXT
self.types
.push(<Postgres as HasSqlType<T>>::type_info().id.0);
self.types.push(<T as Type<Postgres>>::type_info().id.0);
let pos = self.values.len();

View File

@ -1,16 +1,26 @@
use std::convert::TryInto;
use std::ops::Range;
use byteorder::NetworkEndian;
use futures_core::future::BoxFuture;
use std::net::Shutdown;
use futures_core::Future;
use futures_util::TryFutureExt;
use crate::cache::StatementCache;
use crate::connection::{Connect, Connection};
use crate::describe::{Column, Describe};
use crate::io::{Buf, BufStream, MaybeTlsStream};
use crate::postgres::protocol::{self, Authentication, Decode, Encode, Message, StatementId};
use crate::postgres::{sasl, PgError};
use crate::postgres::protocol::{
self, Authentication, AuthenticationMd5, AuthenticationSasl, Decode, Encode, Message,
ParameterDescription, PasswordMessage, RowDescription, StartupMessage, StatementId, Terminate,
};
use crate::postgres::sasl;
use crate::postgres::stream::PgStream;
use crate::postgres::{PgError, PgTypeInfo};
use crate::url::Url;
use crate::{Postgres, Result};
use crate::{Error, Executor, Postgres};
// TODO: TLS
/// An asynchronous connection to a [Postgres][super::Postgres] database.
///
@ -73,27 +83,16 @@ use crate::{Postgres, Result};
/// against the hostname in the server certificate, so they must be the same for the TLS
/// upgrade to succeed.
pub struct PgConnection {
pub(super) stream: BufStream<MaybeTlsStream>,
// Map of query to statement id
pub(super) statement_cache: StatementCache<StatementId>,
// Next statement id
pub(super) stream: PgStream,
pub(super) next_statement_id: u32,
pub(super) is_ready: bool,
// Process ID of the Backend
process_id: u32,
// Backend-unique key to use to send a cancel query message to the server
secret_key: u32,
// Is there a query in progress; are we ready to continue
pub(super) ready: bool,
// TODO: Think of a better way to do this, better name perhaps?
pub(super) data_row_values_buf: Vec<Option<Range<u32>>>,
}
impl PgConnection {
// https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3
async fn startup(&mut self, url: &Url) -> Result<()> {
async fn startup(stream: &mut PgStream, url: &Url) -> crate::Result<()> {
// Defaults to postgres@.../postgres
let username = url.username().unwrap_or("postgres");
let database = url.database().unwrap_or("postgres");
@ -117,42 +116,47 @@ impl PgConnection {
("client_encoding", "UTF-8"),
];
protocol::StartupMessage { params }.encode(self.stream.buffer_mut());
self.stream.flush().await?;
stream.write(StartupMessage { params });
stream.flush().await?;
while let Some(message) = self.receive().await? {
match message {
Message::Authentication(auth) => {
match *auth {
protocol::Authentication::Ok => {
// Do nothing. No password is needed to continue.
loop {
match stream.read().await? {
Message::Authentication => match Authentication::read(stream.buffer())? {
Authentication::Ok => {
// do nothing. no password is needed to continue.
}
protocol::Authentication::ClearTextPassword => {
protocol::PasswordMessage::ClearText(
Authentication::CleartextPassword => {
stream.write(PasswordMessage::ClearText(
&url.password().unwrap_or_default(),
)
.encode(self.stream.buffer_mut());
));
self.stream.flush().await?;
stream.flush().await?;
}
protocol::Authentication::Md5Password { salt } => {
protocol::PasswordMessage::Md5 {
Authentication::Md5Password => {
// TODO: Just reference the salt instead of returning a stack array
// TODO: Better way to make sure we skip the first 4 bytes here
let data = AuthenticationMd5::read(&stream.buffer()[4..])?;
stream.write(PasswordMessage::Md5 {
password: &url.password().unwrap_or_default(),
user: username,
salt,
}
.encode(self.stream.buffer_mut());
salt: data.salt,
});
self.stream.flush().await?;
stream.flush().await?;
}
protocol::Authentication::Sasl { mechanisms } => {
Authentication::Sasl => {
// TODO: Make this iterative for traversing the mechanisms to remove the allocation
// TODO: Better way to make sure we skip the first 4 bytes here
let data = AuthenticationSasl::read(&stream.buffer()[4..])?;
let mut has_sasl: bool = false;
let mut has_sasl_plus: bool = false;
for mechanism in &*mechanisms {
for mechanism in &*data.mechanisms {
match &**mechanism {
"SCRAM-SHA-256" => {
has_sasl = true;
@ -170,43 +174,41 @@ impl PgConnection {
if has_sasl || has_sasl_plus {
// TODO: Handle -PLUS differently if we're in a TLS stream
sasl::authenticate(
self,
username,
&url.password().unwrap_or_default(),
)
sasl::authenticate(stream, username, &url.password().unwrap_or_default())
.await?;
} else {
return Err(protocol_err!(
"unsupported SASL auth mechanisms: {:?}",
mechanisms
data.mechanisms
)
.into());
}
}
auth => {
return Err(protocol_err!(
"requires unimplemented authentication method: {:?}",
auth
)
.into());
}
return Err(
protocol_err!("requested unsupported authentication: {:?}", auth).into(),
);
}
},
Message::BackendKeyData => {
// do nothing. we do not care about the server values here.
// todo: we should care and store these on the connection
}
Message::BackendKeyData(body) => {
self.process_id = body.process_id;
self.secret_key = body.secret_key;
Message::ParameterStatus => {
// do nothing. we do not care about the server values here.
}
Message::ReadyForQuery(_) => {
// Connection fully established and ready to receive queries.
Message::ReadyForQuery => {
// done. connection is now fully established and can accept
// queries for execution.
break;
}
message => {
return Err(protocol_err!("received unexpected message: {:?}", message).into());
type_ => {
return Err(protocol_err!("unexpected message: {:?}", type_).into());
}
}
}
@ -214,160 +216,146 @@ impl PgConnection {
Ok(())
}
// https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.10
async fn terminate(mut self) -> Result<()> {
protocol::Terminate.encode(self.stream.buffer_mut());
self.stream.flush().await?;
self.stream.stream.shutdown(Shutdown::Both)?;
// https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.10
async fn terminate(mut stream: PgStream) -> crate::Result<()> {
stream.write(Terminate);
stream.flush().await?;
stream.shutdown()?;
Ok(())
}
// Wait and return the next message to be received from Postgres.
pub(super) async fn receive(&mut self) -> Result<Option<Message>> {
loop {
// Read the message header (id + len)
let mut header = ret_if_none!(self.stream.peek(5).await?);
let id = header.get_u8()?;
let len = (header.get_u32::<NetworkEndian>()? - 4) as usize;
// Read the message body
self.stream.consume(5);
let body = ret_if_none!(self.stream.peek(len).await?);
let message = match id {
b'N' | b'E' => Message::Response(Box::new(protocol::Response::decode(body)?)),
b'D' => Message::DataRow(protocol::DataRow::decode(body)?),
b'S' => {
Message::ParameterStatus(Box::new(protocol::ParameterStatus::decode(body)?))
}
b'Z' => Message::ReadyForQuery(protocol::ReadyForQuery::decode(body)?),
b'R' => Message::Authentication(Box::new(protocol::Authentication::decode(body)?)),
b'K' => Message::BackendKeyData(protocol::BackendKeyData::decode(body)?),
b'C' => Message::CommandComplete(protocol::CommandComplete::decode(body)?),
b'A' => Message::NotificationResponse(Box::new(
protocol::NotificationResponse::decode(body)?,
)),
b'1' => Message::ParseComplete,
b'2' => Message::BindComplete,
b'3' => Message::CloseComplete,
b'n' => Message::NoData,
b's' => Message::PortalSuspended,
b't' => Message::ParameterDescription(Box::new(
protocol::ParameterDescription::decode(body)?,
)),
b'T' => Message::RowDescription(Box::new(protocol::RowDescription::decode(body)?)),
id => {
return Err(protocol_err!("received unknown message id: {:?}", id).into());
}
};
self.stream.consume(len);
match message {
Message::ParameterStatus(_body) => {
// TODO: not sure what to do with these yet
}
Message::Response(body) => {
if body.severity.is_error() {
// This is an error, stop the world and bubble as an error
return Err(PgError(body).into());
} else {
// This is a _warning_
// TODO: Log the warning
}
}
message => {
return Ok(Some(message));
}
}
}
}
}
impl PgConnection {
pub(super) async fn establish(url: Result<Url>) -> Result<Self> {
pub(super) async fn new(url: crate::Result<Url>) -> crate::Result<Self> {
let url = url?;
let mut stream = PgStream::new(&url).await?;
let stream = MaybeTlsStream::connect(&url, 5432).await?;
let mut self_ = Self {
stream: BufStream::new(stream),
process_id: 0,
secret_key: 0,
// Important to start at 1 as 0 means "unnamed" in our protocol
startup(&mut stream, &url).await?;
Ok(Self {
stream,
data_row_values_buf: Vec::new(),
next_statement_id: 1,
statement_cache: StatementCache::new(),
ready: true,
is_ready: true,
})
}
pub(super) async fn wait_until_ready(&mut self) -> crate::Result<()> {
// depending on how the previous query finished we may need to continue
// pulling messages from the stream until we receive a [ReadyForQuery] message
// postgres sends the [ReadyForQuery] message when it's fully complete with processing
// the previous query
if !self.is_ready {
loop {
if let Message::ReadyForQuery = self.stream.read().await? {
// we are now ready to go
self.is_ready = true;
break;
}
}
}
Ok(())
}
async fn describe<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
) -> crate::Result<Describe<Postgres>> {
let statement = self.write_prepare(query, &Default::default());
self.write_describe(protocol::Describe::Statement(statement));
self.write_sync();
self.stream.flush().await?;
self.wait_until_ready().await?;
let params = loop {
match self.stream.read().await? {
Message::ParseComplete => {
// ignore complete messsage
// continue
}
Message::ParameterDescription => {
break ParameterDescription::read(self.stream.buffer())?;
}
message => {
return Err(protocol_err!(
"expected ParameterDescription; received {:?}",
message
)
.into());
}
};
};
let ssl_mode = url.get_param("sslmode").unwrap_or("prefer".into());
let result = match self.stream.read().await? {
Message::NoData => None,
Message::RowDescription => Some(RowDescription::read(self.stream.buffer())?),
match &*ssl_mode {
// TODO: on "allow" retry with TLS if startup fails
"disable" | "allow" => (),
#[cfg(feature = "tls")]
"prefer" => {
if !self_.try_ssl(&url, true, true).await? {
log::warn!("server does not support TLS, falling back to unsecured connection")
}
}
#[cfg(not(feature = "tls"))]
"prefer" => log::info!("compiled without TLS, skipping upgrade"),
#[cfg(feature = "tls")]
"require" | "verify-ca" | "verify-full" => {
if !self_
.try_ssl(
&url,
ssl_mode == "require", // false for both verify-ca and verify-full
ssl_mode != "verify-full", // false for only verify-full
message => {
return Err(protocol_err!(
"expected RowDescription or NoData; received {:?}",
message
)
.await?
{
return Err(tls_err!("Postgres server does not support TLS").into());
}
.into());
}
};
#[cfg(not(feature = "tls"))]
"require" | "verify-ca" | "verify-full" => {
return Err(tls_err!(
"sslmode {:?} unsupported; SQLx was compiled without `tls` feature",
ssl_mode
)
.into())
}
_ => return Err(tls_err!("unknown `sslmode` value: {:?}", ssl_mode).into()),
}
self_.stream.clear_bufs();
self_.startup(&url).await?;
Ok(self_)
Ok(Describe {
param_types: params
.ids
.iter()
.map(|id| PgTypeInfo::new(*id))
.collect::<Vec<_>>()
.into_boxed_slice(),
result_columns: result
.map(|r| r.fields)
.unwrap_or_default()
.into_vec()
.into_iter()
// TODO: Should [Column] just wrap [protocol::Field] ?
.map(|field| Column {
name: field.name,
table_id: field.table_id,
type_info: PgTypeInfo::new(field.type_id),
})
.collect::<Vec<_>>()
.into_boxed_slice(),
})
}
}
impl Connect for PgConnection {
fn connect<T>(url: T) -> BoxFuture<'static, Result<PgConnection>>
fn connect<T>(url: T) -> BoxFuture<'static, crate::Result<PgConnection>>
where
T: TryInto<Url, Error = crate::Error>,
Self: Sized,
{
Box::pin(PgConnection::establish(url.try_into()))
Box::pin(PgConnection::new(url.try_into()))
}
}
impl Connection for PgConnection {
type Database = Postgres;
fn close(self) -> BoxFuture<'static, Result<()>> {
Box::pin(self.terminate())
fn close(self) -> BoxFuture<'static, crate::Result<()>> {
Box::pin(terminate(self.stream))
}
fn ping(&mut self) -> BoxFuture<crate::Result<()>> {
Box::pin(self.execute("SELECT 1").map_ok(|_| ()))
}
#[doc(hidden)]
fn describe<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
) -> BoxFuture<'e, crate::Result<Describe<Self::Database>>> {
Box::pin(self.describe(query))
}
}

View File

@ -1,56 +1,329 @@
use std::future::Future;
use std::mem;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use async_stream::try_stream;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use crate::cursor::Cursor;
use crate::connection::{ConnectionSource, MaybeOwnedConnection};
use crate::cursor::{Cursor, MapRowFn};
use crate::database::HasRow;
use crate::postgres::protocol::StatementId;
use crate::postgres::PgConnection;
use crate::Postgres;
use crate::executor::Execute;
use crate::pool::{Pool, PoolConnection};
use crate::postgres::protocol::{CommandComplete, DataRow, Message, StatementId};
use crate::postgres::{PgArguments, PgConnection, PgRow};
use crate::{Database, Postgres};
pub struct PgCursor<'a> {
statement: StatementId,
connection: &'a mut PgConnection,
enum State<'c, 'q> {
Query(&'q str, Option<PgArguments>),
NextRow,
// Used for `impl Future`
Resolve(BoxFuture<'c, crate::Result<MaybeOwnedConnection<'c, PgConnection>>>),
AffectedRows(BoxFuture<'c, crate::Result<u64>>),
}
impl<'a> PgCursor<'a> {
pub(super) fn from_connection(
connection: &'a mut PgConnection,
statement: StatementId,
) -> Self {
Self {
connection,
statement,
}
}
pub struct PgCursor<'c, 'q> {
source: ConnectionSource<'c, PgConnection>,
state: State<'c, 'q>,
}
impl<'a> Cursor<'a> for PgCursor<'a> {
impl<'c, 'q> Cursor<'c, 'q> for PgCursor<'c, 'q> {
type Database = Postgres;
fn first(self) -> BoxFuture<'a, crate::Result<Option<<Self::Database as HasRow>::Row>>> {
todo!()
}
fn next(&mut self) -> BoxFuture<crate::Result<Option<<Self::Database as HasRow>::Row>>> {
todo!()
}
fn map<T, F>(self, f: F) -> BoxStream<'a, crate::Result<T>>
#[doc(hidden)]
fn from_pool<E>(pool: &Pool<<Self::Database as Database>::Connection>, query: E) -> Self
where
F: Fn(<Self::Database as HasRow>::Row) -> T,
Self: Sized,
E: Execute<'q, Self::Database>,
{
todo!()
let (query, arguments) = query.into_parts();
Self {
// note: pool is internally reference counted
source: ConnectionSource::Pool(pool.clone()),
state: State::Query(query, arguments),
}
}
impl<'a> Future for PgCursor<'a> {
#[doc(hidden)]
fn from_connection<E, C>(conn: C, query: E) -> Self
where
Self: Sized,
C: Into<MaybeOwnedConnection<'c, <Self::Database as Database>::Connection>>,
E: Execute<'q, Self::Database>,
{
let (query, arguments) = query.into_parts();
Self {
// note: pool is internally reference counted
source: ConnectionSource::Connection(conn.into()),
state: State::Query(query, arguments),
}
}
fn first(self) -> BoxFuture<'c, crate::Result<Option<<Self::Database as HasRow<'c>>::Row>>>
where
'q: 'c,
{
Box::pin(first(self))
}
fn next(&mut self) -> BoxFuture<crate::Result<Option<<Self::Database as HasRow<'_>>::Row>>> {
Box::pin(next(self))
}
fn map<T, F>(mut self, f: F) -> BoxStream<'c, crate::Result<T>>
where
F: MapRowFn<Self::Database, T>,
T: 'c + Send + Unpin,
'q: 'c,
{
Box::pin(try_stream! {
while let Some(row) = self.next().await? {
yield f.call(row);
}
})
}
}
impl<'s, 'q> Future for PgCursor<'s, 'q> {
type Output = crate::Result<u64>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
todo!()
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
match &mut self.state {
State::Query(q, arguments) => {
// todo: existential types can remove both the boxed futures
// and this allocation
let query = q.to_owned();
let arguments = mem::take(arguments);
self.state = State::Resolve(Box::pin(resolve(
mem::take(&mut self.source),
query,
arguments,
)));
}
State::Resolve(fut) => {
match fut.as_mut().poll(cx) {
Poll::Pending => {
return Poll::Pending;
}
Poll::Ready(conn) => {
let conn = conn?;
self.state = State::AffectedRows(Box::pin(affected_rows(conn)));
// continue
}
}
}
State::NextRow => {
panic!("PgCursor must not be polled after being used");
}
State::AffectedRows(fut) => {
return fut.as_mut().poll(cx);
}
}
}
}
}
// write out query to the connection stream
async fn write(
conn: &mut PgConnection,
query: &str,
arguments: Option<PgArguments>,
) -> crate::Result<()> {
// TODO: Handle [arguments] being None. This should be a SIMPLE query.
let arguments = arguments.unwrap();
// Check the statement cache for a statement ID that matches the given query
// If it doesn't exist, we generate a new statement ID and write out [Parse] to the
// connection command buffer
let statement = conn.write_prepare(query, &arguments);
// Next, [Bind] attaches the arguments to the statement and creates a named portal
conn.write_bind("", statement, &arguments);
// Next, [Describe] will return the expected result columns and types
// Conditionally run [Describe] only if the results have not been cached
// if !self.statement_cache.has_columns(statement) {
// self.write_describe(protocol::Describe::Portal(""));
// }
// Next, [Execute] then executes the named portal
conn.write_execute("", 0);
// Finally, [Sync] asks postgres to process the messages that we sent and respond with
// a [ReadyForQuery] message when it's completely done. Theoretically, we could send
// dozens of queries before a [Sync] and postgres can handle that. Execution on the server
// is still serial but it would reduce round-trips. Some kind of builder pattern that is
// termed batching might suit this.
conn.write_sync();
conn.wait_until_ready().await?;
conn.stream.flush().await?;
conn.is_ready = false;
Ok(())
}
async fn resolve(
mut source: ConnectionSource<'_, PgConnection>,
query: String,
arguments: Option<PgArguments>,
) -> crate::Result<MaybeOwnedConnection<'_, PgConnection>> {
let mut conn = source.resolve_by_ref().await?;
write(&mut *conn, &query, arguments).await?;
Ok(source.into_connection())
}
async fn affected_rows(mut conn: MaybeOwnedConnection<'_, PgConnection>) -> crate::Result<u64> {
conn.wait_until_ready().await?;
conn.stream.flush().await?;
conn.is_ready = false;
let mut rows = 0;
loop {
match conn.stream.read().await? {
Message::ParseComplete | Message::BindComplete => {
// ignore x_complete messages
}
Message::DataRow => {
// ignore rows
// TODO: should we log or something?
}
Message::CommandComplete => {
rows += CommandComplete::read(conn.stream.buffer())?.affected_rows;
}
Message::ReadyForQuery => {
// done
break;
}
message => {
return Err(protocol_err!("unexpected message: {:?}", message).into());
}
}
}
Ok(rows)
}
async fn next<'a, 'c: 'a, 'q: 'a>(
cursor: &'a mut PgCursor<'c, 'q>,
) -> crate::Result<Option<PgRow<'a>>> {
let mut conn = cursor.source.resolve_by_ref().await?;
match cursor.state {
State::Query(q, ref mut arguments) => {
// write out the query to the connection
write(&mut *conn, q, arguments.take()).await?;
// next time we come through here, skip this block
cursor.state = State::NextRow;
}
State::Resolve(_) | State::AffectedRows(_) => {
panic!("`PgCursor` must not be used after being polled");
}
State::NextRow => {
// grab the next row
}
}
loop {
match conn.stream.read().await? {
Message::ParseComplete | Message::BindComplete => {
// ignore x_complete messages
}
Message::CommandComplete => {
// no more rows
break;
}
Message::DataRow => {
let data = DataRow::read(&mut *conn)?;
return Ok(Some(PgRow {
connection: conn,
columns: Arc::default(),
data,
}));
}
message => {
return Err(protocol_err!("unexpected message: {:?}", message).into());
}
}
}
Ok(None)
}
async fn first<'c, 'q>(mut cursor: PgCursor<'c, 'q>) -> crate::Result<Option<PgRow<'c>>> {
let mut conn = cursor.source.resolve().await?;
match cursor.state {
State::Query(q, ref mut arguments) => {
// write out the query to the connection
write(&mut conn, q, arguments.take()).await?;
}
State::NextRow => {
// just grab the next row as the first
}
State::Resolve(_) | State::AffectedRows(_) => {
panic!("`PgCursor` must not be used after being polled");
}
}
loop {
match conn.stream.read().await? {
Message::ParseComplete | Message::BindComplete => {
// ignore x_complete messages
}
Message::CommandComplete => {
// no more rows
break;
}
Message::DataRow => {
let data = DataRow::read(&mut conn)?;
return Ok(Some(PgRow {
connection: conn,
columns: Arc::default(),
data,
}));
}
message => {
return Err(protocol_err!("unexpected message: {:?}", message).into());
}
}
}
Ok(None)
}

View File

@ -13,18 +13,18 @@ impl Database for Postgres {
type TableId = u32;
}
impl HasRow for Postgres {
impl<'a> HasRow<'a> for Postgres {
// TODO: Can we drop the `type Database = _`
type Database = Postgres;
type Row = super::PgRow;
type Row = super::PgRow<'a>;
}
impl<'a> HasCursor<'a> for Postgres {
impl<'s, 'q> HasCursor<'s, 'q> for Postgres {
// TODO: Can we drop the `type Database = _`
type Database = Postgres;
type Cursor = super::PgCursor<'a>;
type Cursor = super::PgCursor<'s, 'q>;
}
impl<'a> HasRawValue<'a> for Postgres {

View File

@ -1,7 +1,7 @@
use crate::error::DatabaseError;
use crate::postgres::protocol::Response;
pub struct PgError(pub(super) Box<Response>);
pub struct PgError(pub(super) Response);
impl DatabaseError for PgError {
fn message(&self) -> &str {

View File

@ -2,37 +2,36 @@ use std::collections::HashMap;
use std::io;
use std::sync::Arc;
use crate::cursor::Cursor;
use crate::executor::{Execute, Executor};
use crate::postgres::protocol::{self, Encode, Message, StatementId, TypeFormat};
use crate::postgres::{PgArguments, PgCursor, PgRow, PgTypeInfo, Postgres};
use crate::postgres::protocol::{self, Encode, StatementId, TypeFormat};
use crate::postgres::{PgArguments, PgConnection, PgCursor, PgRow, PgTypeInfo, Postgres};
impl PgConnection {
pub(crate) fn write_prepare(&mut self, query: &str, args: &PgArguments) -> StatementId {
// TODO: check query cache
impl super::PgConnection {
fn write_prepare(&mut self, query: &str, args: &PgArguments) -> StatementId {
if let Some(&id) = self.statement_cache.get(query) {
id
} else {
let id = StatementId(self.next_statement_id);
self.next_statement_id += 1;
protocol::Parse {
self.stream.write(protocol::Parse {
statement: id,
query,
param_types: &*args.types,
}
.encode(self.stream.buffer_mut());
});
self.statement_cache.put(query.to_owned(), id);
// TODO: write to query cache
id
}
pub(crate) fn write_describe(&mut self, d: protocol::Describe) {
self.stream.write(d);
}
fn write_describe(&mut self, d: protocol::Describe) {
d.encode(self.stream.buffer_mut())
}
fn write_bind(&mut self, portal: &str, statement: StatementId, args: &PgArguments) {
protocol::Bind {
pub(crate) fn write_bind(&mut self, portal: &str, statement: StatementId, args: &PgArguments) {
self.stream.write(protocol::Bind {
portal,
statement,
formats: &[TypeFormat::Binary],
@ -40,59 +39,30 @@ impl super::PgConnection {
values_len: args.types.len() as i16,
values: &*args.values,
result_formats: &[TypeFormat::Binary],
}
.encode(self.stream.buffer_mut());
});
}
fn write_execute(&mut self, portal: &str, limit: i32) {
protocol::Execute { portal, limit }.encode(self.stream.buffer_mut());
pub(crate) fn write_execute(&mut self, portal: &str, limit: i32) {
self.stream.write(protocol::Execute { portal, limit });
}
fn write_sync(&mut self) {
protocol::Sync.encode(self.stream.buffer_mut());
pub(crate) fn write_sync(&mut self) {
self.stream.write(protocol::Sync);
}
}
impl<'e> Executor<'e> for &'e mut super::PgConnection {
type Database = Postgres;
fn execute<'q, E>(self, query: E) -> PgCursor<'e>
fn execute<'q, E>(self, query: E) -> PgCursor<'e, 'q>
where
E: Execute<'q, Self::Database>,
{
let (query, arguments) = query.into_parts();
// TODO: Handle [arguments] being None. This should be a SIMPLE query.
let arguments = arguments.unwrap();
// Check the statement cache for a statement ID that matches the given query
// If it doesn't exist, we generate a new statement ID and write out [Parse] to the
// connection command buffer
let statement = self.write_prepare(query, &arguments);
// Next, [Bind] attaches the arguments to the statement and creates a named portal
self.write_bind("", statement, &arguments);
// Next, [Describe] will return the expected result columns and types
// Conditionally run [Describe] only if the results have not been cached
if !self.statement_cache.has_columns(statement) {
self.write_describe(protocol::Describe::Portal(""));
PgCursor::from_connection(self, query)
}
// Next, [Execute] then executes the named portal
self.write_execute("", 0);
// Finally, [Sync] asks postgres to process the messages that we sent and respond with
// a [ReadyForQuery] message when it's completely done. Theoretically, we could send
// dozens of queries before a [Sync] and postgres can handle that. Execution on the server
// is still serial but it would reduce round-trips. Some kind of builder pattern that is
// termed batching might suit this.
self.write_sync();
PgCursor::from_connection(self, statement)
}
fn execute_by_ref<'q, E>(&mut self, query: E) -> PgCursor<'_>
#[inline]
fn execute_by_ref<'q, E>(&mut self, query: E) -> PgCursor<'_, 'q>
where
E: Execute<'q, Self::Database>,
{

View File

@ -17,7 +17,8 @@ mod executor;
mod protocol;
mod row;
mod sasl;
mod tls;
mod stream;
// mod tls;
mod types;
/// An alias for [`Pool`][crate::Pool], specialized for **Postgres**.

View File

@ -5,95 +5,120 @@ use std::str;
#[derive(Debug)]
pub enum Authentication {
/// Authentication was successful.
/// The authentication exchange is successfully completed.
Ok,
/// Kerberos V5 authentication is required.
/// The frontend must now take part in a Kerberos V5 authentication dialog (not described
/// here, part of the Kerberos specification) with the server. If this is successful,
/// the server responds with an `AuthenticationOk`, otherwise it responds
/// with an `ErrorResponse`. This is no longer supported.
KerberosV5,
/// A clear-text password is required.
ClearTextPassword,
/// The frontend must now send a `PasswordMessage` containing the password in clear-text form.
/// If this is the correct password, the server responds with an `AuthenticationOk`, otherwise it
/// responds with an `ErrorResponse`.
CleartextPassword,
/// An MD5-encrypted password is required.
Md5Password { salt: [u8; 4] },
/// The frontend must now send a `PasswordMessage` containing the password (with user name)
/// encrypted via MD5, then encrypted again using the 4-byte random salt specified in the
/// `AuthenticationMD5Password` message. If this is the correct password, the server responds
/// with an `AuthenticationOk`, otherwise it responds with an `ErrorResponse`.
Md5Password,
/// An SCM credentials message is required.
/// This response is only possible for local Unix-domain connections on platforms that support
/// SCM credential messages. The frontend must issue an SCM credential message and then
/// send a single data byte.
ScmCredential,
/// GSSAPI authentication is required.
/// The frontend must now initiate a GSSAPI negotiation. The frontend will send a
/// `GSSResponse` message with the first part of the GSSAPI data stream in response to this.
Gss,
/// SSPI authentication is required.
/// The frontend must now initiate a SSPI negotiation.
/// The frontend will send a GSSResponse with the first part of the SSPI data stream in
/// response to this.
Sspi,
/// This message contains GSSAPI or SSPI data.
GssContinue { data: Box<[u8]> },
/// This message contains the response data from the previous step of GSSAPI
/// or SSPI negotiation.
GssContinue,
/// SASL authentication is required.
///
/// The message body is a list of SASL authentication mechanisms,
/// in the server's order of preference.
Sasl { mechanisms: Box<[Box<str>]> },
/// The frontend must now initiate a SASL negotiation, using one of the SASL mechanisms
/// listed in the message.
Sasl,
/// This message contains a SASL challenge.
SaslContinue(SaslContinue),
/// This message contains challenge data from the previous step of SASL negotiation.
SaslContinue,
/// SASL authentication has completed.
SaslFinal { data: Box<[u8]> },
/// SASL authentication has completed with additional mechanism-specific data for the client.
SaslFinal,
}
impl Authentication {
pub fn read(mut buf: &[u8]) -> crate::Result<Self> {
Ok(match buf.get_u32::<NetworkEndian>()? {
0 => Authentication::Ok,
2 => Authentication::KerberosV5,
3 => Authentication::CleartextPassword,
5 => Authentication::Md5Password,
6 => Authentication::ScmCredential,
7 => Authentication::Gss,
8 => Authentication::GssContinue,
9 => Authentication::Sspi,
10 => Authentication::Sasl,
11 => Authentication::SaslContinue,
12 => Authentication::SaslFinal,
type_ => {
return Err(protocol_err!("unknown authentication message type: {}", type_).into());
}
})
}
}
#[derive(Debug)]
pub struct SaslContinue {
pub salt: Vec<u8>,
pub iter_count: u32,
pub nonce: Vec<u8>,
pub data: String,
pub struct AuthenticationMd5 {
pub salt: [u8; 4],
}
impl Decode for Authentication {
fn decode(mut buf: &[u8]) -> crate::Result<Self> {
Ok(match buf.get_u32::<NetworkEndian>()? {
0 => Authentication::Ok,
2 => Authentication::KerberosV5,
3 => Authentication::ClearTextPassword,
5 => {
impl AuthenticationMd5 {
pub fn read(mut buf: &[u8]) -> crate::Result<Self> {
let mut salt = [0_u8; 4];
salt.copy_from_slice(&buf);
salt.copy_from_slice(buf);
Authentication::Md5Password { salt }
}
6 => Authentication::ScmCredential,
7 => Authentication::Gss,
8 => {
let mut data = Vec::with_capacity(buf.len());
data.extend_from_slice(buf);
Authentication::GssContinue {
data: data.into_boxed_slice(),
Ok(Self { salt })
}
}
9 => Authentication::Sspi,
#[derive(Debug)]
pub struct AuthenticationSasl {
pub mechanisms: Box<[Box<str>]>,
}
10 => {
impl AuthenticationSasl {
pub fn read(mut buf: &[u8]) -> crate::Result<Self> {
let mut mechanisms = Vec::new();
while buf[0] != 0 {
mechanisms.push(buf.get_str_nul()?.into());
}
Authentication::Sasl {
Ok(Self {
mechanisms: mechanisms.into_boxed_slice(),
})
}
}
11 => {
#[derive(Debug)]
pub struct AuthenticationSaslContinue {
pub salt: Vec<u8>,
pub iter_count: u32,
pub nonce: Vec<u8>,
pub data: String,
}
impl AuthenticationSaslContinue {
pub fn read(mut buf: &[u8]) -> crate::Result<Self> {
let mut salt: Vec<u8> = Vec::new();
let mut nonce: Vec<u8> = Vec::new();
let mut iter_count: u32 = 0;
@ -125,32 +150,31 @@ impl Decode for Authentication {
}
}
Authentication::SaslContinue(SaslContinue {
Ok(Self {
salt: base64::decode(&salt).map_err(|_| {
protocol_err!("salt value response from postgres was not base64 encoded")
})?,
nonce,
iter_count,
data: str::from_utf8(buf)
.map_err(|_| {
protocol_err!("SaslContinue response was not a valid utf8 string")
})?
.map_err(|_| protocol_err!("SaslContinue response was not a valid utf8 string"))?
.to_string(),
})
}
}
12 => {
#[derive(Debug)]
pub struct AuthenticationSaslFinal {
pub data: Box<[u8]>,
}
impl AuthenticationSaslFinal {
pub fn read(mut buf: &[u8]) -> crate::Result<Self> {
let mut data = Vec::with_capacity(buf.len());
data.extend_from_slice(buf);
Authentication::SaslFinal {
Ok(Self {
data: data.into_boxed_slice(),
}
}
id => {
return Err(protocol_err!("unknown authentication response: {}", id).into());
}
})
}
}
@ -158,27 +182,25 @@ impl Decode for Authentication {
#[cfg(test)]
mod tests {
use super::{Authentication, Decode};
use crate::postgres::protocol::authentication::AuthenticationMd5;
use matches::assert_matches;
const AUTH_OK: &[u8] = b"\0\0\0\0";
const AUTH_MD5: &[u8] = b"\0\0\0\x05\x93\x189\x98";
#[test]
fn it_decodes_auth_ok() {
let m = Authentication::decode(AUTH_OK).unwrap();
fn it_reads_auth_ok() {
let m = Authentication::read(AUTH_OK).unwrap();
assert_matches!(m, Authentication::Ok);
}
#[test]
fn it_decodes_auth_md5_password() {
let m = Authentication::decode(AUTH_MD5).unwrap();
fn it_reads_auth_md5_password() {
let m = Authentication::read(AUTH_MD5).unwrap();
let data = AuthenticationMd5::read(&AUTH_MD5[4..]).unwrap();
assert_matches!(
m,
Authentication::Md5Password {
salt: [147, 24, 57, 152]
}
);
assert_matches!(m, Authentication::Md5Password);
assert_eq!(data.salt, [147, 24, 57, 152]);
}
}

View File

@ -6,8 +6,8 @@ pub struct CommandComplete {
pub affected_rows: u64,
}
impl Decode for CommandComplete {
fn decode(mut buf: &[u8]) -> crate::Result<Self> {
impl CommandComplete {
pub(crate) fn read(mut buf: &[u8]) -> crate::Result<Self> {
// Attempt to parse the last word in the command tag as an integer
// If it can't be parsed, the tag is probably "CREATE TABLE" or something
// and we should return 0 rows
@ -35,29 +35,29 @@ mod tests {
const COMMAND_COMPLETE_BEGIN: &[u8] = b"BEGIN\0";
#[test]
fn it_decodes_command_complete_for_insert() {
let message = CommandComplete::decode(COMMAND_COMPLETE_INSERT).unwrap();
fn it_reads_command_complete_for_insert() {
let message = CommandComplete::read(COMMAND_COMPLETE_INSERT).unwrap();
assert_eq!(message.affected_rows, 1);
}
#[test]
fn it_decodes_command_complete_for_update() {
let message = CommandComplete::decode(COMMAND_COMPLETE_UPDATE).unwrap();
fn it_reads_command_complete_for_update() {
let message = CommandComplete::read(COMMAND_COMPLETE_UPDATE).unwrap();
assert_eq!(message.affected_rows, 512);
}
#[test]
fn it_decodes_command_complete_for_begin() {
let message = CommandComplete::decode(COMMAND_COMPLETE_BEGIN).unwrap();
fn it_reads_command_complete_for_begin() {
let message = CommandComplete::read(COMMAND_COMPLETE_BEGIN).unwrap();
assert_eq!(message.affected_rows, 0);
}
#[test]
fn it_decodes_command_complete_for_create_table() {
let message = CommandComplete::decode(COMMAND_COMPLETE_CREATE_TABLE).unwrap();
fn it_reads_command_complete_for_create_table() {
let message = CommandComplete::read(COMMAND_COMPLETE_CREATE_TABLE).unwrap();
assert_eq!(message.affected_rows, 0);
}

View File

@ -1,34 +1,49 @@
use crate::io::{Buf, ByteStr};
use crate::postgres::protocol::Decode;
use crate::postgres::PgConnection;
use byteorder::NetworkEndian;
use std::fmt::{self, Debug};
use std::ops::Range;
pub struct DataRow {
buffer: Box<[u8]>,
values: Box<[Option<Range<u32>>]>,
len: u16,
}
impl DataRow {
pub fn len(&self) -> usize {
self.values.len()
self.len as usize
}
pub fn get(&self, index: usize) -> Option<&[u8]> {
let range = self.values[index].as_ref()?;
pub fn get<'a>(
&self,
buffer: &'a [u8],
values: &[Option<Range<u32>>],
index: usize,
) -> Option<&'a [u8]> {
let range = values[index].as_ref()?;
Some(&self.buffer[(range.start as usize)..(range.end as usize)])
Some(&buffer[(range.start as usize)..(range.end as usize)])
}
}
impl Decode for DataRow {
fn decode(mut buf: &[u8]) -> crate::Result<Self> {
let len = buf.get_u16::<NetworkEndian>()? as usize;
let buffer: Box<[u8]> = buf.into();
let mut values = Vec::with_capacity(len);
let mut index = 4;
impl DataRow {
pub(crate) fn read<'a>(
connection: &mut PgConnection,
// buffer: &'a [u8],
// values: &'a mut Vec<Option<Range<u32>>>,
) -> crate::Result<Self> {
let buffer = connection.stream.buffer();
let values = &mut connection.data_row_values_buf;
while values.len() < len {
values.clear();
let mut buf = buffer;
let len = buf.get_u16::<NetworkEndian>()?;
let mut index = 6;
while values.len() < (len as usize) {
// The length of the column value, in bytes (this count does not include itself).
// Can be zero. As a special case, -1 indicates a NULL column value.
// No value bytes follow in the NULL case.
@ -46,26 +61,7 @@ impl Decode for DataRow {
}
}
Ok(Self {
values: values.into_boxed_slice(),
buffer,
})
}
}
impl Debug for DataRow {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "DataRow(")?;
let len = self.values.len();
f.debug_list()
.entries((0..len).map(|i| self.get(i).map(ByteStr)))
.finish()?;
write!(f, ")")?;
Ok(())
Ok(Self { len })
}
}
@ -76,18 +72,14 @@ mod tests {
const DATA_ROW: &[u8] = b"\0\x03\0\0\0\x011\0\0\0\x012\0\0\0\x013";
#[test]
fn it_decodes_data_row() {
let m = DataRow::decode(DATA_ROW).unwrap();
fn it_reads_data_row() {
let mut values = Vec::new();
let m = DataRow::read(DATA_ROW, &mut values).unwrap();
assert_eq!(m.values.len(), 3);
assert_eq!(m.len, 3);
assert_eq!(m.get(0), Some(&b"1"[..]));
assert_eq!(m.get(1), Some(&b"2"[..]));
assert_eq!(m.get(2), Some(&b"3"[..]));
assert_eq!(
format!("{:?}", m),
"DataRow([Some(b\"1\"), Some(b\"2\"), Some(b\"3\")])"
);
assert_eq!(m.get(DATA_ROW, &values, 0), Some(&b"1"[..]));
assert_eq!(m.get(DATA_ROW, &values, 1), Some(&b"2"[..]));
assert_eq!(m.get(DATA_ROW, &values, 2), Some(&b"3"[..]));
}
}

View File

@ -1,24 +1,57 @@
use std::convert::TryFrom;
use crate::postgres::protocol::{
Authentication, BackendKeyData, CommandComplete, DataRow, NotificationResponse,
ParameterDescription, ParameterStatus, ReadyForQuery, Response, RowDescription,
};
#[derive(Debug)]
#[derive(Debug, Copy, Clone)]
#[repr(u8)]
pub enum Message {
Authentication(Box<Authentication>),
ParameterStatus(Box<ParameterStatus>),
BackendKeyData(BackendKeyData),
ReadyForQuery(ReadyForQuery),
CommandComplete(CommandComplete),
DataRow(DataRow),
Response(Box<Response>),
NotificationResponse(Box<NotificationResponse>),
ParseComplete,
Authentication,
BackendKeyData,
BindComplete,
CloseComplete,
CommandComplete,
DataRow,
NoData,
NotificationResponse,
ParameterDescription,
ParameterStatus,
ParseComplete,
PortalSuspended,
ParameterDescription(Box<ParameterDescription>),
RowDescription(Box<RowDescription>),
ReadyForQuery,
NoticeResponse,
ErrorResponse,
RowDescription,
}
impl TryFrom<u8> for Message {
type Error = crate::Error;
fn try_from(type_: u8) -> crate::Result<Self> {
// https://www.postgresql.org/docs/12/protocol-message-formats.html
Ok(match type_ {
b'E' => Message::ErrorResponse,
b'N' => Message::NoticeResponse,
b'D' => Message::DataRow,
b'S' => Message::ParameterStatus,
b'Z' => Message::ReadyForQuery,
b'R' => Message::Authentication,
b'K' => Message::BackendKeyData,
b'C' => Message::CommandComplete,
b'A' => Message::NotificationResponse,
b'1' => Message::ParseComplete,
b'2' => Message::BindComplete,
b'3' => Message::CloseComplete,
b'n' => Message::NoData,
b's' => Message::PortalSuspended,
b't' => Message::ParameterDescription,
b'T' => Message::RowDescription,
id => {
return Err(protocol_err!("unknown message: {:?}", id).into());
}
})
}
}

View File

@ -58,7 +58,10 @@ mod row_description;
mod message;
pub use authentication::Authentication;
pub use authentication::{
Authentication, AuthenticationMd5, AuthenticationSasl, AuthenticationSaslContinue,
AuthenticationSaslFinal,
};
pub use backend_key_data::BackendKeyData;
pub use command_complete::CommandComplete;
pub use data_row::DataRow;

View File

@ -7,8 +7,8 @@ pub struct ParameterDescription {
pub ids: Box<[TypeId]>,
}
impl Decode for ParameterDescription {
fn decode(mut buf: &[u8]) -> crate::Result<Self> {
impl ParameterDescription {
pub(crate) fn read(mut buf: &[u8]) -> crate::Result<Self> {
let cnt = buf.get_u16::<NetworkEndian>()? as usize;
let mut ids = Vec::with_capacity(cnt);
@ -27,9 +27,9 @@ mod test {
use super::{Decode, ParameterDescription};
#[test]
fn it_decodes_parameter_description() {
fn it_reads_parameter_description() {
let buf = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00";
let desc = ParameterDescription::decode(buf).unwrap();
let desc = ParameterDescription::read(buf).unwrap();
assert_eq!(desc.ids.len(), 2);
assert_eq!(desc.ids[0].0, 0x0000_0000);
@ -37,9 +37,9 @@ mod test {
}
#[test]
fn it_decodes_empty_parameter_description() {
fn it_reads_empty_parameter_description() {
let buf = b"\x00\x00";
let desc = ParameterDescription::decode(buf).unwrap();
let desc = ParameterDescription::read(buf).unwrap();
assert_eq!(desc.ids.len(), 0);
}

View File

@ -65,8 +65,8 @@ pub struct Response {
pub routine: Option<Box<str>>,
}
impl Decode for Response {
fn decode(mut buf: &[u8]) -> crate::Result<Self> {
impl Response {
pub fn read(mut buf: &[u8]) -> crate::Result<Self> {
let mut code = None::<Box<str>>;
let mut message = None::<Box<str>>;
let mut severity = None::<Box<str>>;

View File

@ -18,8 +18,8 @@ pub struct Field {
pub type_format: TypeFormat,
}
impl Decode for RowDescription {
fn decode(mut buf: &[u8]) -> crate::Result<Self> {
impl RowDescription {
pub(crate) fn read(mut buf: &[u8]) -> crate::Result<Self> {
let cnt = buf.get_u16::<NetworkEndian>()? as usize;
let mut fields = Vec::with_capacity(cnt);
@ -57,7 +57,7 @@ mod test {
use super::{Decode, RowDescription};
#[test]
fn it_decodes_row_description() {
fn it_reads_row_description() {
#[rustfmt::skip]
let buf = bytes! {
// Number of Parameters
@ -82,7 +82,7 @@ mod test {
0_u8, 0_u8 // format_code
};
let desc = RowDescription::decode(&buf).unwrap();
let desc = RowDescription::read(&buf).unwrap();
assert_eq!(desc.fields.len(), 2);
assert_eq!(desc.fields[0].type_id.0, 0x0000_0000);
@ -90,9 +90,9 @@ mod test {
}
#[test]
fn it_decodes_empty_row_description() {
fn it_reads_empty_row_description() {
let buf = b"\x00\x00";
let desc = RowDescription::decode(buf).unwrap();
let desc = RowDescription::read(buf).unwrap();
assert_eq!(desc.fields.len(), 0);
}

View File

@ -1,58 +1,58 @@
use std::collections::HashMap;
use std::sync::Arc;
use crate::connection::MaybeOwnedConnection;
use crate::decode::Decode;
use crate::pool::PoolConnection;
use crate::postgres::protocol::DataRow;
use crate::postgres::Postgres;
use crate::postgres::{PgConnection, Postgres};
use crate::row::{Row, RowIndex};
use crate::types::HasSqlType;
use crate::types::Type;
pub struct PgRow {
pub struct PgRow<'c> {
pub(super) connection: MaybeOwnedConnection<'c, PgConnection>,
pub(super) data: DataRow,
pub(super) columns: Arc<HashMap<Box<str>, usize>>,
}
impl Row for PgRow {
impl<'c> Row<'c> for PgRow<'c> {
type Database = Postgres;
fn len(&self) -> usize {
self.data.len()
}
fn get<T, I>(&self, index: I) -> T
fn try_get_raw<'i, I>(&'c self, index: I) -> crate::Result<Option<&'c [u8]>>
where
Self::Database: HasSqlType<T>,
I: RowIndex<Self>,
T: Decode<Self::Database>,
I: RowIndex<'c, Self> + 'i,
{
index.try_get(self).unwrap()
index.try_get_raw(self)
}
}
impl RowIndex<PgRow> for usize {
fn try_get<T>(&self, row: &PgRow) -> crate::Result<T>
where
<PgRow as Row>::Database: HasSqlType<T>,
T: Decode<<PgRow as Row>::Database>,
{
Ok(Decode::decode_nullable(row.data.get(*self))?)
impl<'c> RowIndex<'c, PgRow<'c>> for usize {
fn try_get_raw(self, row: &'c PgRow<'c>) -> crate::Result<Option<&'c [u8]>> {
Ok(row.data.get(
row.connection.stream.buffer(),
&row.connection.data_row_values_buf,
self,
))
}
}
impl RowIndex<PgRow> for &'_ str {
fn try_get<T>(&self, row: &PgRow) -> crate::Result<T>
where
<PgRow as Row>::Database: HasSqlType<T>,
T: Decode<<PgRow as Row>::Database>,
{
let index = row
.columns
.get(*self)
.ok_or_else(|| crate::Error::ColumnNotFound((*self).into()))?;
let value = Decode::decode_nullable(row.data.get(*index))?;
// impl<'c> RowIndex<'c, PgRow<'c>> for &'_ str {
// fn try_get_raw(self, row: &'r PgRow<'c>) -> crate::Result<Option<&'c [u8]>> {
// let index = row
// .columns
// .get(self)
// .ok_or_else(|| crate::Error::ColumnNotFound((*self).into()))?;
//
// Ok(row.data.get(
// row.connection.stream.buffer(),
// &row.connection.data_row_values_buf,
// *index,
// ))
// }
// }
Ok(value)
}
}
impl_from_row_for_row!(PgRow);
// TODO: impl_from_row_for_row!(PgRow);

View File

@ -3,8 +3,10 @@ use rand::Rng;
use sha2::{Digest, Sha256};
use crate::postgres::protocol::{
hi, Authentication, Encode, Message, SaslInitialResponse, SaslResponse,
hi, Authentication, AuthenticationSaslContinue, Encode, Message, SaslInitialResponse,
SaslResponse,
};
use crate::postgres::stream::PgStream;
use crate::postgres::PgConnection;
static GS2_HEADER: &'static str = "n,,";
@ -43,7 +45,7 @@ fn nonce() -> String {
// Performs authenticiton using Simple Authentication Security Layer (SASL) which is what
// Postgres uses
pub(super) async fn authenticate<T: AsRef<str>>(
conn: &mut PgConnection,
stream: &mut PgStream,
username: T,
password: T,
) -> crate::Result<()> {
@ -62,13 +64,18 @@ pub(super) async fn authenticate<T: AsRef<str>>(
client_first_message_bare = client_first_message_bare
);
SaslInitialResponse(&client_first_message).encode(conn.stream.buffer_mut());
conn.stream.flush().await?;
stream.write(SaslInitialResponse(&client_first_message));
stream.flush().await?;
let server_first_message = conn.receive().await?;
let server_first_message = stream.read().await?;
if let Message::Authentication = server_first_message {
let auth = Authentication::read(stream.buffer())?;
if let Authentication::SaslContinue = auth {
// todo: better way to indicate that we consumed just these 4 bytes?
let sasl = AuthenticationSaslContinue::read(&stream.buffer()[4..])?;
if let Some(Message::Authentication(auth)) = server_first_message {
if let Authentication::SaslContinue(sasl) = *auth {
let server_first_message = sasl.data;
// SaltedPassword := Hi(Normalize(password), salt, i)
@ -132,9 +139,11 @@ pub(super) async fn authenticate<T: AsRef<str>>(
client_proof = base64::encode(&client_proof)
);
SaslResponse(&client_final_message).encode(conn.stream.buffer_mut());
conn.stream.flush().await?;
let _server_final_response = conn.receive().await?;
stream.write(SaslResponse(&client_final_message));
stream.flush().await?;
let _server_final_response = stream.read().await?;
// todo: assert that this was SaslFinal?
Ok(())
} else {

View File

@ -0,0 +1,90 @@
use std::convert::TryInto;
use std::net::Shutdown;
use byteorder::NetworkEndian;
use crate::io::{Buf, BufStream, MaybeTlsStream};
use crate::postgres::protocol::{Encode, Message, Response};
use crate::postgres::PgError;
use crate::url::Url;
pub struct PgStream {
stream: BufStream<MaybeTlsStream>,
// Most recently received message
// Is referenced by our buffered stream
// Is initialized to ReadyForQuery/0 at the start
message: (Message, u32),
}
impl PgStream {
pub(super) async fn new(url: &Url) -> crate::Result<Self> {
let stream = MaybeTlsStream::connect(&url, 5432).await?;
Ok(Self {
stream: BufStream::new(stream),
message: (Message::ReadyForQuery, 0),
})
}
pub(super) fn shutdown(&self) -> crate::Result<()> {
Ok(self.stream.shutdown(Shutdown::Both)?)
}
#[inline]
pub(super) fn write<M>(&mut self, message: M)
where
M: Encode,
{
message.encode(self.stream.buffer_mut());
}
#[inline]
pub(super) async fn flush(&mut self) -> crate::Result<()> {
Ok(self.stream.flush().await?)
}
pub(super) async fn read(&mut self) -> crate::Result<Message> {
// https://www.postgresql.org/docs/12/protocol-overview.html#PROTOCOL-MESSAGE-CONCEPTS
// All communication is through a stream of messages. The first byte of a message
// identifies the message type, and the next four bytes specify the length of the rest of
// the message (this length count includes itself, but not the message-type byte).
if self.message.1 > 0 {
// If there is any data in our read buffer we need to make sure we flush that
// so reading will return the *next* message
self.stream.consume(self.message.1 as usize);
}
let mut header = self.stream.peek(4 + 1).await?;
let type_ = header.get_u8()?.try_into()?;
let length = header.get_u32::<NetworkEndian>()? - 4;
self.message = (type_, length);
self.stream.consume(4 + 1);
// Wait until there is enough data in the stream. We then return without actually
// inspecting the data. This is then looked at later through the [buffer] function
let _ = self.stream.peek(length as usize).await?;
if let Message::ErrorResponse = type_ {
// This is an error, bubble up as one immediately
return Err(crate::Error::Database(Box::new(PgError(Response::read(
self.stream.buffer(),
)?))));
}
Ok(type_)
}
/// Returns a reference to the internally buffered message.
///
/// This is the body of the message identified by the most recent call
/// to `read`.
#[inline]
pub(super) fn buffer(&self) -> &[u8] {
&self.stream.buffer()[..(self.message.1 as usize)]
}
}

View File

@ -3,15 +3,15 @@ use crate::encode::Encode;
use crate::postgres::protocol::TypeId;
use crate::postgres::types::PgTypeInfo;
use crate::postgres::Postgres;
use crate::types::HasSqlType;
use crate::types::Type;
impl HasSqlType<bool> for Postgres {
impl Type<Postgres> for bool {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::BOOL)
}
}
impl HasSqlType<[bool]> for Postgres {
impl Type<Postgres> for [bool] {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_BOOL)
}

View File

@ -3,24 +3,24 @@ use crate::encode::Encode;
use crate::postgres::protocol::TypeId;
use crate::postgres::types::PgTypeInfo;
use crate::postgres::Postgres;
use crate::types::HasSqlType;
use crate::types::Type;
impl HasSqlType<[u8]> for Postgres {
impl Type<Postgres> for [u8] {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::BYTEA)
}
}
impl HasSqlType<[&'_ [u8]]> for Postgres {
impl Type<Postgres> for [&'_ [u8]] {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_BYTEA)
}
}
// TODO: Do we need the [HasSqlType] here on the Vec?
impl HasSqlType<Vec<u8>> for Postgres {
impl Type<Postgres> for Vec<u8> {
fn type_info() -> PgTypeInfo {
<Self as HasSqlType<[u8]>>::type_info()
<[u8] as Type<Postgres>>::type_info()
}
}

View File

@ -8,27 +8,27 @@ use crate::encode::Encode;
use crate::postgres::protocol::TypeId;
use crate::postgres::types::PgTypeInfo;
use crate::postgres::Postgres;
use crate::types::HasSqlType;
use crate::types::Type;
impl HasSqlType<NaiveTime> for Postgres {
impl Type<Postgres> for NaiveTime {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::TIME)
}
}
impl HasSqlType<NaiveDate> for Postgres {
impl Type<Postgres> for NaiveDate {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::DATE)
}
}
impl HasSqlType<NaiveDateTime> for Postgres {
impl Type<Postgres> for NaiveDateTime {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::TIMESTAMP)
}
}
impl<Tz> HasSqlType<DateTime<Tz>> for Postgres
impl<Tz> Type<DateTime<Tz>> for Postgres
where
Tz: TimeZone,
{
@ -37,25 +37,25 @@ where
}
}
impl HasSqlType<[NaiveTime]> for Postgres {
impl Type<Postgres> for [NaiveTime] {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_TIME)
}
}
impl HasSqlType<[NaiveDate]> for Postgres {
impl Type<Postgres> for [NaiveDate] {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_DATE)
}
}
impl HasSqlType<[NaiveDateTime]> for Postgres {
impl Type<Postgres> for [NaiveDateTime] {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_TIMESTAMP)
}
}
impl<Tz> HasSqlType<[DateTime<Tz>]> for Postgres
impl<Tz> Type<[DateTime<Tz>]> for Postgres
where
Tz: TimeZone,
{

View File

@ -3,15 +3,15 @@ use crate::encode::Encode;
use crate::postgres::protocol::TypeId;
use crate::postgres::types::PgTypeInfo;
use crate::postgres::Postgres;
use crate::types::HasSqlType;
use crate::types::Type;
impl HasSqlType<f32> for Postgres {
impl Type<Postgres> for f32 {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::FLOAT4)
}
}
impl HasSqlType<[f32]> for Postgres {
impl Type<Postgres> for [f32] {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_FLOAT4)
}
@ -31,13 +31,13 @@ impl Decode<Postgres> for f32 {
}
}
impl HasSqlType<f64> for Postgres {
impl Type<Postgres> for f64 {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::FLOAT8)
}
}
impl HasSqlType<[f64]> for Postgres {
impl Type<Postgres> for [f64] {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_FLOAT8)
}

View File

@ -5,15 +5,15 @@ use crate::encode::Encode;
use crate::postgres::protocol::TypeId;
use crate::postgres::types::PgTypeInfo;
use crate::postgres::Postgres;
use crate::types::HasSqlType;
use crate::types::Type;
impl HasSqlType<i16> for Postgres {
impl Type<Postgres> for i16 {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::INT2)
}
}
impl HasSqlType<[i16]> for Postgres {
impl Type<Postgres> for [i16] {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_INT2)
}
@ -31,13 +31,13 @@ impl Decode<Postgres> for i16 {
}
}
impl HasSqlType<i32> for Postgres {
impl Type<Postgres> for i32 {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::INT4)
}
}
impl HasSqlType<[i32]> for Postgres {
impl Type<Postgres> for [i32] {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_INT4)
}
@ -55,13 +55,13 @@ impl Decode<Postgres> for i32 {
}
}
impl HasSqlType<i64> for Postgres {
impl Type<Postgres> for i64 {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::INT8)
}
}
impl HasSqlType<[i64]> for Postgres {
impl Type<Postgres> for [i64] {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_INT8)
}

View File

@ -4,25 +4,25 @@ use crate::decode::{Decode, DecodeError};
use crate::encode::Encode;
use crate::postgres::protocol::TypeId;
use crate::postgres::types::PgTypeInfo;
use crate::types::HasSqlType;
use crate::types::Type;
use crate::Postgres;
impl HasSqlType<str> for Postgres {
impl Type<Postgres> for str {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::TEXT)
}
}
impl HasSqlType<[&'_ str]> for Postgres {
impl Type<Postgres> for [&'_ str] {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_TEXT)
}
}
// TODO: Do we need [HasSqlType] on String here?
impl HasSqlType<String> for Postgres {
impl Type<Postgres> for String {
fn type_info() -> PgTypeInfo {
<Self as HasSqlType<str>>::type_info()
<str as Type<Postgres>>::type_info()
}
}

View File

@ -5,15 +5,15 @@ use crate::encode::Encode;
use crate::postgres::protocol::TypeId;
use crate::postgres::types::PgTypeInfo;
use crate::postgres::Postgres;
use crate::types::HasSqlType;
use crate::types::Type;
impl HasSqlType<Uuid> for Postgres {
impl Type<Postgres> for Uuid {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::UUID)
}
}
impl HasSqlType<[Uuid]> for Postgres {
impl Type<Postgres> for [Uuid] {
fn type_info() -> PgTypeInfo {
PgTypeInfo::new(TypeId::ARRAY_UUID)
}

View File

@ -4,68 +4,69 @@ use crate::cursor::Cursor;
use crate::database::{Database, HasCursor, HasRow};
use crate::encode::Encode;
use crate::executor::{Execute, Executor};
use crate::types::HasSqlType;
use crate::types::Type;
use futures_core::stream::BoxStream;
use futures_util::future::ready;
use futures_util::TryFutureExt;
use futures_util::TryStreamExt;
use std::future::Future;
use std::marker::PhantomData;
use std::mem;
/// Raw SQL query with bind parameters. Returned by [`query`].
pub struct Query<'a, DB, T = <DB as Database>::Arguments>
pub struct Query<'q, DB, T = <DB as Database>::Arguments>
where
DB: Database,
{
query: &'a str,
query: &'q str,
arguments: T,
database: PhantomData<DB>,
}
impl<'a, DB, P> Execute<'a, DB> for Query<'a, DB, P>
impl<'q, DB, P> Execute<'q, DB> for Query<'q, DB, P>
where
DB: Database,
P: IntoArguments<DB> + Send,
{
fn into_parts(self) -> (&'a str, Option<<DB as Database>::Arguments>) {
fn into_parts(self) -> (&'q str, Option<<DB as Database>::Arguments>) {
(self.query, Some(self.arguments.into_arguments()))
}
}
impl<'a, DB, P> Query<'a, DB, P>
impl<'q, DB, P> Query<'q, DB, P>
where
DB: Database,
P: IntoArguments<DB> + Send,
{
pub fn execute<'b, E>(self, executor: E) -> impl Future<Output = crate::Result<u64>> + 'b
pub async fn execute<'e, E>(self, executor: E) -> crate::Result<u64>
where
E: Executor<'b, Database = DB>,
'a: 'b,
E: Executor<'e, Database = DB>,
{
executor.execute(self).await
}
pub fn fetch<'e, E>(self, executor: E) -> <DB as HasCursor<'e, 'q>>::Cursor
where
E: Executor<'e, Database = DB>,
{
executor.execute(self)
}
pub fn fetch<'b, E>(self, executor: E) -> <DB as HasCursor<'b>>::Cursor
where
E: Executor<'b, Database = DB>,
'a: 'b,
{
executor.execute(self)
}
pub async fn fetch_optional<'b, E>(
pub async fn fetch_optional<'e, E>(
self,
executor: E,
) -> crate::Result<Option<<DB as HasRow>::Row>>
) -> crate::Result<Option<<DB as HasRow<'e>>::Row>>
where
E: Executor<'b, Database = DB>,
E: Executor<'e, Database = DB>,
'q: 'e,
{
executor.execute(self).first().await
}
pub async fn fetch_one<'b, E>(self, executor: E) -> crate::Result<<DB as HasRow>::Row>
pub async fn fetch_one<'e, E>(self, executor: E) -> crate::Result<<DB as HasRow<'e>>::Row>
where
E: Executor<'b, Database = DB>,
E: Executor<'e, Database = DB>,
'q: 'e,
{
self.fetch_optional(executor)
.and_then(|row| match row {
@ -83,7 +84,7 @@ where
/// Bind a value for use with this SQL query.
pub fn bind<T>(mut self, value: T) -> Self
where
DB: HasSqlType<T>,
T: Type<DB>,
T: Encode<DB>,
{
self.arguments.add(value);

View File

@ -2,20 +2,17 @@
use crate::database::Database;
use crate::decode::Decode;
use crate::types::HasSqlType;
use crate::types::Type;
pub trait RowIndex<R: ?Sized>
pub trait RowIndex<'c, R: ?Sized>
where
R: Row,
R: Row<'c>,
{
fn try_get<T>(&self, row: &R) -> crate::Result<T>
where
R::Database: HasSqlType<T>,
T: Decode<R::Database>;
fn try_get_raw(self, row: &'c R) -> crate::Result<Option<&'c [u8]>>;
}
/// Represents a single row of the result set.
pub trait Row: Unpin + Send + 'static {
pub trait Row<'c>: Unpin + Send {
type Database: Database + ?Sized;
/// Returns `true` if the row contains no values.
@ -26,18 +23,34 @@ pub trait Row: Unpin + Send + 'static {
/// Returns the number of values in the row.
fn len(&self) -> usize;
/// Returns the value at the `index`; can either be an integer ordinal or a column name.
fn get<T, I>(&self, index: I) -> T
fn get<T, I>(&'c self, index: I) -> T
where
Self::Database: HasSqlType<T>,
I: RowIndex<Self>,
T: Decode<Self::Database>;
T: Type<Self::Database>,
I: RowIndex<'c, Self>,
T: Decode<Self::Database>,
{
// todo: use expect with a proper message
self.try_get(index).unwrap()
}
fn try_get<T, I>(&'c self, index: I) -> crate::Result<T>
where
T: Type<Self::Database>,
I: RowIndex<'c, Self>,
T: Decode<Self::Database>,
{
Ok(Decode::decode_nullable(self.try_get_raw(index)?)?)
}
fn try_get_raw<'i, I>(&'c self, index: I) -> crate::Result<Option<&'c [u8]>>
where
I: RowIndex<'c, Self> + 'i;
}
/// A **record** that can be built from a row returned from by the database.
pub trait FromRow<R>
pub trait FromRow<'a, R>
where
R: Row,
R: Row<'a>,
{
fn from_row(row: R) -> Self;
}

View File

@ -4,6 +4,7 @@ use futures_core::future::BoxFuture;
use crate::connection::Connection;
use crate::database::HasCursor;
use crate::describe::Describe;
use crate::executor::{Execute, Executor};
use crate::runtime::spawn;
use crate::Database;
@ -19,10 +20,11 @@ where
depth: u32,
}
impl<T> Transaction<T>
impl<DB, T> Transaction<T>
where
T: Connection,
T: Executor<'static>,
T: Connection<Database = DB>,
DB: Database,
T: Executor<'static, Database = DB>,
{
pub(crate) async fn new(depth: u32, mut inner: T) -> crate::Result<Self> {
if depth == 0 {
@ -98,10 +100,11 @@ where
}
}
impl<T> Connection for Transaction<T>
impl<T, DB> Connection for Transaction<T>
where
T: Connection,
T: Executor<'static>,
T: Connection<Database = DB>,
DB: Database,
T: Executor<'static, Database = DB>,
{
type Database = <T as Connection>::Database;
@ -109,9 +112,23 @@ where
fn close(self) -> BoxFuture<'static, crate::Result<()>> {
Box::pin(async move { self.rollback().await?.close().await })
}
#[inline]
fn ping(&mut self) -> BoxFuture<crate::Result<()>> {
Box::pin(self.deref_mut().ping())
}
impl<'a, DB, T> Executor<'a> for &'a mut Transaction<T>
#[doc(hidden)]
#[inline]
fn describe<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
) -> BoxFuture<'e, crate::Result<Describe<Self::Database>>> {
Box::pin(self.deref_mut().describe(query))
}
}
impl<'c, DB, T> Executor<'c> for &'c mut Transaction<T>
where
DB: Database,
T: Connection<Database = DB>,
@ -119,19 +136,19 @@ where
{
type Database = <T as Connection>::Database;
fn execute<'b, E>(self, query: E) -> <<T as Connection>::Database as HasCursor<'a>>::Cursor
fn execute<'q, E>(self, query: E) -> <<T as Connection>::Database as HasCursor<'c, 'q>>::Cursor
where
E: Execute<'b, Self::Database>,
E: Execute<'q, Self::Database>,
{
(**self).execute_by_ref(query)
}
fn execute_by_ref<'b, 'c, E>(
&'c mut self,
fn execute_by_ref<'q, 'e, E>(
&'e mut self,
query: E,
) -> <Self::Database as HasCursor<'c>>::Cursor
) -> <Self::Database as HasCursor<'e, 'q>>::Cursor
where
E: Execute<'b, Self::Database>,
E: Execute<'q, Self::Database>,
{
(**self).execute_by_ref(query)
}

View File

@ -21,29 +21,34 @@ pub trait TypeInfo: Debug + Display + Clone {
}
/// Indicates that a SQL type is supported for a database.
pub trait HasSqlType<T: ?Sized>: Database {
pub trait Type<DB>
where
DB: Database,
{
/// Returns the canonical type information on the database for the type `T`.
fn type_info() -> Self::TypeInfo;
fn type_info() -> DB::TypeInfo;
}
// For references to types in Rust, the underlying SQL type information
// is equivalent
impl<T: ?Sized, DB> HasSqlType<&'_ T> for DB
impl<T: ?Sized, DB> Type<DB> for &'_ T
where
DB: HasSqlType<T>,
DB: Database,
T: Type<DB>,
{
fn type_info() -> Self::TypeInfo {
<DB as HasSqlType<T>>::type_info()
fn type_info() -> DB::TypeInfo {
<T as Type<DB>>::type_info()
}
}
// For optional types in Rust, the underlying SQL type information
// is equivalent
impl<T, DB> HasSqlType<Option<T>> for DB
impl<T, DB> Type<DB> for Option<T>
where
DB: HasSqlType<T>,
DB: Database,
T: Type<DB>,
{
fn type_info() -> Self::TypeInfo {
<DB as HasSqlType<T>>::type_info()
fn type_info() -> DB::TypeInfo {
<T as Type<DB>>::type_info()
}
}

View File

@ -32,7 +32,7 @@ macro_rules! impl_database_ext {
$(
// `if` statements cannot have attributes but these can
$(#[$meta])?
_ if sqlx::types::TypeInfo::compatible(&<$database as sqlx::types::HasSqlType<$ty>>::type_info(), &info) => Some(input_ty!($ty $(, $input)?)),
_ if sqlx::types::TypeInfo::compatible(&<$ty as sqlx::types::Type<$database>>::type_info(), &info) => Some(input_ty!($ty $(, $input)?)),
)*
_ => None
}
@ -42,7 +42,7 @@ macro_rules! impl_database_ext {
match () {
$(
$(#[$meta])?
_ if sqlx::types::TypeInfo::compatible(&<$database as sqlx::types::HasSqlType<$ty>>::type_info(), &info) => return Some(stringify!($ty)),
_ if sqlx::types::TypeInfo::compatible(&<$ty as sqlx::types::Type<$database>>::type_info(), &info) => return Some(stringify!($ty)),
)*
_ => None
}

View File

@ -12,7 +12,7 @@ pub use sqlx_core::{arguments, describe, error, pool, row, types};
// Types
pub use sqlx_core::{
Connect, Connection, Database, Error, Executor, FromRow, Pool, Query, QueryAs, Result, Row,
Connect, Connection, Cursor, Database, Error, Executor, FromRow, Pool, Query, QueryAs, Result, Row,
Transaction,
};

View File

@ -1,59 +1,59 @@
use sqlx::{postgres::PgConnection, Connection as _, Row};
async fn connect() -> anyhow::Result<PgConnection> {
Ok(PgConnection::open(dotenv::var("DATABASE_URL")?).await?)
}
macro_rules! test {
($name:ident: $ty:ty: $($text:literal == $value:expr),+) => {
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn $name () -> anyhow::Result<()> {
let mut conn = connect().await?;
$(
let row = sqlx::query(&format!("SELECT {} = $1, $1 as _1", $text))
.bind($value)
.fetch_one(&mut conn)
.await?;
assert!(row.get::<bool, _>(0));
assert!($value == row.get::<$ty, _>("_1"));
)+
Ok(())
}
}
}
test!(postgres_bool: bool: "false::boolean" == false, "true::boolean" == true);
test!(postgres_smallint: i16: "821::smallint" == 821_i16);
test!(postgres_int: i32: "94101::int" == 94101_i32);
test!(postgres_bigint: i64: "9358295312::bigint" == 9358295312_i64);
test!(postgres_real: f32: "9419.122::real" == 9419.122_f32);
test!(postgres_double: f64: "939399419.1225182::double precision" == 939399419.1225182_f64);
test!(postgres_text: String: "'this is foo'" == "this is foo", "''" == "");
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn postgres_bytes() -> anyhow::Result<()> {
let mut conn = connect().await?;
let value = b"Hello, World";
let row = sqlx::query("SELECT E'\\\\x48656c6c6f2c20576f726c64' = $1, $1")
.bind(&value[..])
.fetch_one(&mut conn)
.await?;
assert!(row.get::<bool, _>(0));
let output: Vec<u8> = row.get(1);
assert_eq!(&value[..], &*output);
Ok(())
}
// use sqlx::{postgres::PgConnection, Connect as _, Connection as _, Row};
//
// async fn connect() -> anyhow::Result<PgConnection> {
// Ok(PgConnection::connect(dotenv::var("DATABASE_URL")?).await?)
// }
//
// macro_rules! test {
// ($name:ident: $ty:ty: $($text:literal == $value:expr),+) => {
// #[cfg_attr(feature = "runtime-async-std", async_std::test)]
// #[cfg_attr(feature = "runtime-tokio", tokio::test)]
// async fn $name () -> anyhow::Result<()> {
// let mut conn = connect().await?;
//
// $(
// let row = sqlx::query(&format!("SELECT {} = $1, $1 as _1", $text))
// .bind($value)
// .fetch_one(&mut conn)
// .await?;
//
// assert!(row.get::<bool, _>(0));
// assert!($value == row.get::<$ty, _>("_1"));
// )+
//
// Ok(())
// }
// }
// }
//
// test!(postgres_bool: bool: "false::boolean" == false, "true::boolean" == true);
//
// test!(postgres_smallint: i16: "821::smallint" == 821_i16);
// test!(postgres_int: i32: "94101::int" == 94101_i32);
// test!(postgres_bigint: i64: "9358295312::bigint" == 9358295312_i64);
//
// test!(postgres_real: f32: "9419.122::real" == 9419.122_f32);
// test!(postgres_double: f64: "939399419.1225182::double precision" == 939399419.1225182_f64);
//
// test!(postgres_text: String: "'this is foo'" == "this is foo", "''" == "");
//
// #[cfg_attr(feature = "runtime-async-std", async_std::test)]
// #[cfg_attr(feature = "runtime-tokio", tokio::test)]
// async fn postgres_bytes() -> anyhow::Result<()> {
// let mut conn = connect().await?;
//
// let value = b"Hello, World";
//
// let row = sqlx::query("SELECT E'\\\\x48656c6c6f2c20576f726c64' = $1, $1")
// .bind(&value[..])
// .fetch_one(&mut conn)
// .await?;
//
// assert!(row.get::<bool, _>(0));
//
// let output: Vec<u8> = row.get(1);
//
// assert_eq!(&value[..], &*output);
//
// Ok(())
// }

View File

@ -1,5 +1,5 @@
use futures::TryStreamExt;
use sqlx::{postgres::PgConnection, Connection as _, Executor as _, Row as _};
use sqlx::{postgres::PgConnection, Connect, Connection, Executor, Row};
use sqlx_core::postgres::PgPool;
use std::time::Duration;
@ -17,58 +17,40 @@ async fn it_connects() -> anyhow::Result<()> {
Ok(())
}
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn it_executes() -> anyhow::Result<()> {
let mut conn = connect().await?;
let _ = conn
.send(
r#"
CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY);
"#,
)
.await?;
for index in 1..=10_i32 {
let cnt = sqlx::query("INSERT INTO users (id) VALUES ($1)")
.bind(index)
.execute(&mut conn)
.await?;
assert_eq!(cnt, 1);
}
let sum: i32 = sqlx::query("SELECT id FROM users")
.fetch(&mut conn)
.try_fold(
0_i32,
|acc, x| async move { Ok(acc + x.get::<i32, _>("id")) },
)
.await?;
assert_eq!(sum, 55);
Ok(())
}
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn it_remains_stable_issue_30() -> anyhow::Result<()> {
let mut conn = connect().await?;
// This tests the internal buffer wrapping around at the end
// Specifically: https://github.com/launchbadge/sqlx/issues/30
let rows = sqlx::query("SELECT i, random()::text FROM generate_series(1, 1000) as i")
.fetch_all(&mut conn)
.await?;
assert_eq!(rows.len(), 1000);
assert_eq!(rows[rows.len() - 1].get::<i32, _>(0), 1000);
Ok(())
}
// #[cfg_attr(feature = "runtime-async-std", async_std::test)]
// #[cfg_attr(feature = "runtime-tokio", tokio::test)]
// async fn it_executes() -> anyhow::Result<()> {
// let mut conn = connect().await?;
//
// let _ = conn
// .send(
// r#"
// CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY);
// "#,
// )
// .await?;
//
// for index in 1..=10_i32 {
// let cnt = sqlx::query("INSERT INTO users (id) VALUES ($1)")
// .bind(index)
// .execute(&mut conn)
// .await?;
//
// assert_eq!(cnt, 1);
// }
//
// let sum: i32 = sqlx::query("SELECT id FROM users")
// .fetch(&mut conn)
// .try_fold(
// 0_i32,
// |acc, x| async move { Ok(acc + x.get::<i32, _>("id")) },
// )
// .await?;
//
// assert_eq!(sum, 55);
//
// Ok(())
// }
// https://github.com/launchbadge/sqlx/issues/104
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
@ -122,7 +104,7 @@ async fn pool_smoke_test() -> anyhow::Result<()> {
let pool = pool.clone();
spawn(async move {
loop {
if let Err(e) = sqlx::query("select 1 + 1").fetch_one(&mut &pool).await {
if let Err(e) = sqlx::query("select 1 + 1").fetch_one(&pool).await {
eprintln!("pool task {} dying due to {}", i, e);
break;
}
@ -159,5 +141,5 @@ async fn pool_smoke_test() -> anyhow::Result<()> {
async fn connect() -> anyhow::Result<PgConnection> {
let _ = dotenv::dotenv();
let _ = env_logger::try_init();
Ok(PgConnection::open(dotenv::var("DATABASE_URL")?).await?)
Ok(PgConnection::connect(dotenv::var("DATABASE_URL")?).await?)
}