Remove the RawConnection concept and fold into Backend

This commit is contained in:
Ryan Leckey
2019-11-22 11:46:49 +00:00
parent c7ce15f977
commit 061b7819ab
16 changed files with 608 additions and 776 deletions

View File

@@ -1,20 +1,61 @@
use crate::{connection::RawConnection, query::QueryParameters, row::Row, types::HasTypeMetadata};
use crate::describe::Describe;
use crate::{query::QueryParameters, row::Row, types::HasTypeMetadata};
use async_trait::async_trait;
use futures_core::stream::BoxStream;
/// A database backend.
///
/// This trait represents the concept of a backend (e.g. "MySQL" vs "SQLite").
pub trait Backend: HasTypeMetadata + Sized {
/// Represents a connection to the database and further provides auxillary but
/// important related traits as associated types.
///
/// This trait is not intended to be used directly.
/// Instead [sqlx::Connection] or [sqlx::Pool] should be used instead,
/// which provide concurrent access and typed retrieval of results.
#[async_trait]
pub trait Backend: HasTypeMetadata + Send + Sync + Sized {
/// The concrete `QueryParameters` implementation for this backend.
type QueryParameters: QueryParameters<Backend = Self>;
/// The concrete `RawConnection` implementation for this backend.
type RawConnection: RawConnection<Backend = Self>;
/// The concrete `Row` implementation for this backend. This type is returned
/// from methods in the `RawConnection`.
/// The concrete `Row` implementation for this backend.
type Row: Row<Backend = Self>;
/// The identifier for tables; in Postgres this is an `oid` while
/// in MariaDB/MySQL this is the qualified name of the table.
type TableIdent;
/// Establish a new connection to the database server.
async fn open(url: &str) -> crate::Result<Self>
where
Self: Sized;
/// Release resources for this database connection immediately.
///
/// This method is not required to be called. A database server will
/// eventually notice and clean up not fully closed connections.
async fn close(mut self) -> crate::Result<()>;
async fn ping(&mut self) -> crate::Result<()> {
// TODO: Does this need to be specialized for any database backends?
let _ = self
.execute("SELECT 1", Self::QueryParameters::new())
.await?;
Ok(())
}
async fn describe(&mut self, query: &str) -> crate::Result<Describe<Self>>;
async fn execute(&mut self, query: &str, params: Self::QueryParameters) -> crate::Result<u64>;
fn fetch(
&mut self,
query: &str,
params: Self::QueryParameters,
) -> BoxStream<'_, crate::Result<Self::Row>>;
async fn fetch_optional(
&mut self,
query: &str,
params: Self::QueryParameters,
) -> crate::Result<Option<Self::Row>>;
}

View File

@@ -4,77 +4,12 @@ use crate::{
error::Error,
executor::Executor,
pool::{Live, SharedPool},
query::{IntoQueryParameters, QueryParameters},
query::IntoQueryParameters,
row::FromSqlRow,
};
use async_trait::async_trait;
use crossbeam_queue::SegQueue;
use crossbeam_utils::atomic::AtomicCell;
use futures_channel::oneshot::{channel, Sender};
use futures_core::{future::BoxFuture, stream::BoxStream};
use futures_util::stream::StreamExt;
use std::{
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
time::Instant,
};
/// A connection to the database.
///
/// This trait is not intended to be used directly. Instead [sqlx::Connection] or [sqlx::Pool] should be used instead, which provide
/// concurrent access and typed retrieval of results.
#[async_trait]
pub trait RawConnection: Send + Sync {
// The database backend this type connects to.
type Backend: Backend;
/// Establish a new connection to the database server.
async fn establish(url: &str) -> crate::Result<Self>
where
Self: Sized;
/// Release resources for this database connection immediately.
///
/// This method is not required to be called. A database server will eventually notice
/// and clean up not fully closed connections.
///
/// It is safe to close an already closed connection.
async fn close(mut self) -> crate::Result<()>;
/// Verifies a connection to the database is still alive.
async fn ping(&mut self) -> crate::Result<()> {
let _ = self
.execute(
"SELECT 1",
<<Self::Backend as Backend>::QueryParameters>::new(),
)
.await?;
Ok(())
}
async fn execute(
&mut self,
query: &str,
params: <Self::Backend as Backend>::QueryParameters,
) -> crate::Result<u64>;
fn fetch(
&mut self,
query: &str,
params: <Self::Backend as Backend>::QueryParameters,
) -> BoxStream<'_, crate::Result<<Self::Backend as Backend>::Row>>;
async fn fetch_optional(
&mut self,
query: &str,
params: <Self::Backend as Backend>::QueryParameters,
) -> crate::Result<Option<<Self::Backend as Backend>::Row>>;
async fn describe(&mut self, query: &str) -> crate::Result<Describe<Self::Backend>>;
}
use std::{sync::Arc, time::Instant};
pub struct Connection<DB>
where
@@ -95,8 +30,8 @@ where
}
}
pub async fn establish(url: &str) -> crate::Result<Self> {
let raw = <DB as Backend>::RawConnection::establish(url).await?;
pub async fn open(url: &str) -> crate::Result<Self> {
let raw = DB::open(url).await?;
let live = Live {
raw,
since: Instant::now(),

View File

@@ -13,7 +13,9 @@ pub struct Describe<DB: Backend> {
}
impl<DB: Backend> fmt::Debug for Describe<DB>
where <DB as HasTypeMetadata>::TypeId: fmt::Debug, ResultField<DB>: fmt::Debug
where
<DB as HasTypeMetadata>::TypeId: fmt::Debug,
ResultField<DB>: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("Describe")
@@ -31,7 +33,9 @@ pub struct ResultField<DB: Backend> {
}
impl<DB: Backend> fmt::Debug for ResultField<DB>
where <DB as Backend>::TableIdent: fmt::Debug, <DB as HasTypeMetadata>::TypeId: fmt::Debug
where
<DB as Backend>::TableIdent: fmt::Debug,
<DB as HasTypeMetadata>::TypeId: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("ResultField")

View File

@@ -1,7 +1,9 @@
use async_std::io::{
prelude::{ReadExt, WriteExt},
Read, Write,
};
use bytes::{BufMut, BytesMut};
use std::io;
use async_std::io::{Read, Write, prelude::{ReadExt, WriteExt}};
use async_std::future::poll_fn;
pub struct BufStream<S> {
pub(crate) stream: S,

View File

@@ -1,13 +1,87 @@
use super::{MariaDb, MariaDbQueryParameters, MariaDbRow};
use crate::backend::Backend;
use crate::describe::{Describe, ResultField};
use crate::mariadb::protocol::ColumnDefinitionPacket;
use async_trait::async_trait;
use futures_core::stream::BoxStream;
#[derive(Debug)]
pub struct MariaDb;
#[async_trait]
impl Backend for MariaDb {
type QueryParameters = super::MariaDbQueryParameters;
type RawConnection = super::MariaDbRawConnection;
type Row = super::MariaDbRow;
type QueryParameters = MariaDbQueryParameters;
type Row = MariaDbRow;
type TableIdent = String;
async fn open(url: &str) -> crate::Result<Self>
where
Self: Sized,
{
MariaDb::open(url).await
}
async fn close(mut self) -> crate::Result<()> {
self.close().await
}
async fn ping(&mut self) -> crate::Result<()> {
self.ping().await
}
async fn execute(&mut self, query: &str, params: MariaDbQueryParameters) -> crate::Result<u64> {
// Write prepare statement to buffer
self.start_sequence();
let prepare_ok = self.send_prepare(query).await?;
let affected = self.execute(prepare_ok.statement_id, params).await?;
Ok(affected)
}
fn fetch(
&mut self,
_query: &str,
_params: MariaDbQueryParameters,
) -> BoxStream<'_, crate::Result<MariaDbRow>> {
unimplemented!();
}
async fn fetch_optional(
&mut self,
_query: &str,
_params: MariaDbQueryParameters,
) -> crate::Result<Option<Self::Row>> {
unimplemented!();
}
async fn describe(&mut self, query: &str) -> crate::Result<Describe<MariaDb>> {
let prepare_ok = self.send_prepare(query).await?;
let mut param_types = Vec::with_capacity(prepare_ok.params as usize);
for _ in 0..prepare_ok.params {
let param = ColumnDefinitionPacket::decode(self.receive().await?)?;
param_types.push(param.field_type.0);
}
self.check_eof().await?;
let mut columns = Vec::with_capacity(prepare_ok.columns as usize);
for _ in 0..prepare_ok.columns {
let column = ColumnDefinitionPacket::decode(self.receive().await?)?;
columns.push(ResultField {
name: column.column_alias.or(column.column),
table_id: column.table_alias.or(column.table),
type_id: column.field_type.0,
})
}
self.check_eof().await?;
Ok(Describe {
param_types,
result_fields: columns,
})
}
}
impl_from_sql_row_tuples_for_backend!(MariaDb);

View File

@@ -1,8 +1,5 @@
use super::establish;
use crate::{
connection::RawConnection,
describe::{Describe, ResultField},
error::DatabaseError,
io::{Buf, BufMut, BufStream},
mariadb::{
protocol::{
@@ -10,31 +7,27 @@ use crate::{
ComStmtExecute, ComStmtPrepare, ComStmtPrepareOk, Encode, EofPacket, ErrPacket,
OkPacket, ResultRow, StmtExecFlag,
},
MariaDb, MariaDbQueryParameters, MariaDbRow,
MariaDbQueryParameters,
},
Backend, Error, Result,
Error, Result,
};
use async_trait::async_trait;
use async_std::net::TcpStream;
use byteorder::{ByteOrder, LittleEndian};
use futures_core::{future::BoxFuture, stream::BoxStream};
use futures_util::stream::{self, StreamExt};
use std::{
future::Future,
io,
net::{IpAddr, SocketAddr},
};
use async_std::net::TcpStream;
use url::Url;
pub struct MariaDbRawConnection {
pub struct MariaDb {
pub(crate) stream: BufStream<TcpStream>,
pub(crate) rbuf: Vec<u8>,
pub(crate) capabilities: Capabilities,
next_seq_no: u8,
}
impl MariaDbRawConnection {
async fn establish(url: &str) -> Result<Self> {
impl MariaDb {
pub async fn open(url: &str) -> Result<Self> {
// TODO: Handle errors
let url = Url::parse(url).unwrap();
@@ -110,7 +103,7 @@ impl MariaDbRawConnection {
Ok(Some(&self.rbuf[..len]))
}
fn start_sequence(&mut self) {
pub(super) fn start_sequence(&mut self) {
// At the start of a command sequence we reset our understanding
// of [next_seq_no]. In a sequence our initial command must be 0, followed
// by the server response that is 1, then our response to that response (if any),
@@ -147,7 +140,7 @@ impl MariaDbRawConnection {
// to terminate immediately
pub(crate) async fn receive_ok_or_err(&mut self) -> Result<OkPacket> {
let capabilities = self.capabilities;
let mut buf = self.receive().await?;
let buf = self.receive().await?;
Ok(match buf[0] {
0xfe | 0x00 => OkPacket::decode(buf, capabilities)?,
@@ -175,7 +168,7 @@ impl MariaDbRawConnection {
})
}
async fn check_eof(&mut self) -> Result<()> {
pub(super) async fn check_eof(&mut self) -> Result<()> {
if !self
.capabilities
.contains(Capabilities::CLIENT_DEPRECATE_EOF)
@@ -186,7 +179,10 @@ impl MariaDbRawConnection {
Ok(())
}
async fn send_prepare<'c>(&'c mut self, statement: &'c str) -> Result<ComStmtPrepareOk> {
pub(super) async fn send_prepare<'c>(
&'c mut self,
statement: &'c str,
) -> Result<ComStmtPrepareOk> {
self.stream.flush().await?;
self.start_sequence();
@@ -204,7 +200,11 @@ impl MariaDbRawConnection {
ComStmtPrepareOk::decode(packet).map_err(Into::into)
}
async fn execute(&mut self, statement_id: u32, params: MariaDbQueryParameters) -> Result<u64> {
pub(super) async fn execute(
&mut self,
statement_id: u32,
_params: MariaDbQueryParameters,
) -> Result<u64> {
// TODO: EXECUTE(READ_ONLY) => FETCH instead of EXECUTE(NO)
// SEND ================
@@ -279,130 +279,3 @@ impl MariaDbRawConnection {
Ok(rows)
}
}
enum ExecResult {
NoRows(OkPacket),
Rows(Vec<ColumnDefinitionPacket>),
}
#[async_trait]
impl RawConnection for MariaDbRawConnection {
type Backend = MariaDb;
async fn establish(url: &str) -> crate::Result<Self>
where
Self: Sized,
{
MariaDbRawConnection::establish(url).await
}
async fn close(mut self) -> crate::Result<()> {
self.close().await
}
async fn ping(&mut self) -> crate::Result<()> {
self.ping().await
}
async fn execute(&mut self, query: &str, params: MariaDbQueryParameters) -> crate::Result<u64> {
// Write prepare statement to buffer
self.start_sequence();
let prepare_ok = self.send_prepare(query).await?;
let affected = self.execute(prepare_ok.statement_id, params).await?;
Ok(affected)
}
fn fetch(
&mut self,
query: &str,
params: MariaDbQueryParameters,
) -> BoxStream<'_, Result<MariaDbRow>> {
unimplemented!();
}
async fn fetch_optional(
&mut self,
query: &str,
params: MariaDbQueryParameters,
) -> crate::Result<Option<<Self::Backend as Backend>::Row>> {
unimplemented!();
}
async fn describe(&mut self, query: &str) -> crate::Result<Describe<MariaDb>> {
let prepare_ok = self.send_prepare(query).await?;
let mut param_types = Vec::with_capacity(prepare_ok.params as usize);
for _ in 0..prepare_ok.params {
let param = ColumnDefinitionPacket::decode(self.receive().await?)?;
param_types.push(param.field_type.0);
}
self.check_eof().await?;
let mut columns = Vec::with_capacity(prepare_ok.columns as usize);
for _ in 0..prepare_ok.columns {
let column = ColumnDefinitionPacket::decode(self.receive().await?)?;
columns.push(ResultField {
name: column.column_alias.or(column.column),
table_id: column.table_alias.or(column.table),
type_id: column.field_type.0,
})
}
self.check_eof().await?;
Ok(Describe {
param_types,
result_fields: columns,
})
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::{query::QueryParameters, Error, Pool};
#[async_std::test]
async fn it_can_connect() -> Result<()> {
MariaDbRawConnection::establish("mariadb://root@127.0.0.1:3306/test").await?;
Ok(())
}
#[async_std::test]
async fn it_fails_to_connect_with_bad_username() -> Result<()> {
match MariaDbRawConnection::establish("mariadb://roote@127.0.0.1:3306/test").await {
Ok(_) => panic!("Somehow connected to database with incorrect username"),
Err(_) => Ok(()),
}
}
#[async_std::test]
async fn it_can_ping() -> Result<()> {
let mut conn =
MariaDbRawConnection::establish("mariadb://root@127.0.0.1:3306/test").await?;
conn.ping().await?;
Ok(())
}
#[async_std::test]
async fn it_can_describe() -> Result<()> {
let mut conn =
MariaDbRawConnection::establish("mysql://sqlx_user@127.0.0.1:3306/sqlx_test").await?;
let describe = conn.describe("SELECT id from accounts where id = ?").await?;
dbg!(describe);
Ok(())
}
#[async_std::test]
async fn it_can_create_mariadb_pool() -> Result<()> {
let pool: Pool<MariaDb> = Pool::new("mariadb://root@127.0.0.1:3306/test").await?;
Ok(())
}
}

View File

@@ -1,13 +1,13 @@
use crate::{
mariadb::{
connection::MariaDbRawConnection,
protocol::{Capabilities, Encode, HandshakeResponsePacket, InitialHandshakePacket},
connection::MariaDb,
protocol::{Capabilities, HandshakeResponsePacket, InitialHandshakePacket},
},
Result,
};
use url::Url;
pub(crate) async fn establish(conn: &mut MariaDbRawConnection, url: &Url) -> Result<()> {
pub(crate) async fn establish(conn: &mut MariaDb, url: &Url) -> Result<()> {
let initial = InitialHandshakePacket::decode(conn.receive().await?)?;
// TODO: Capabilities::SECURE_CONNECTION
@@ -17,7 +17,7 @@ pub(crate) async fn establish(conn: &mut MariaDbRawConnection, url: &Url) -> Res
// TODO: Capabilities::TRANSACTIONS
// TODO: Capabilities::CLIENT_DEPRECATE_EOF
// TODO?: Capabilities::CLIENT_SESSION_TRACK
let mut capabilities = Capabilities::CLIENT_PROTOCOL_41 | Capabilities::CONNECT_WITH_DB;
let capabilities = Capabilities::CLIENT_PROTOCOL_41 | Capabilities::CONNECT_WITH_DB;
let response = HandshakeResponsePacket {
// TODO: Find a good value for [max_packet_size]

View File

@@ -168,6 +168,7 @@ mod tests {
assert_eq!(&buf[..], b"\xFC\xFE\x00");
}
#[test]
fn it_encodes_int_lenenc_ff() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian, _>(0xFF as u64);

View File

@@ -8,7 +8,4 @@ mod query;
mod row;
pub mod types;
pub use self::{
backend::MariaDb, connection::MariaDbRawConnection, query::MariaDbQueryParameters,
row::MariaDbRow,
};
pub use self::{connection::MariaDb, query::MariaDbQueryParameters, row::MariaDbRow};

View File

@@ -1,10 +1,6 @@
use crate::{
backend::Backend,
connection::{Connection, RawConnection},
error::Error,
executor::Executor,
query::IntoQueryParameters,
row::FromSqlRow,
backend::Backend, connection::Connection, error::Error, executor::Executor,
query::IntoQueryParameters, row::FromSqlRow,
};
use crossbeam_queue::{ArrayQueue, SegQueue};
use futures_channel::oneshot;
@@ -238,7 +234,7 @@ where
if self.size.compare_and_swap(size, size + 1, Ordering::AcqRel) == size {
// Open a new connection and return directly
let raw = <DB as Backend>::RawConnection::establish(&self.url).await?;
let raw = DB::open(&self.url).await?;
let live = Live {
raw,
since: Instant::now(),
@@ -404,7 +400,7 @@ pub(crate) struct Live<DB>
where
DB: Backend,
{
pub(crate) raw: DB::RawConnection,
pub(crate) raw: DB,
#[allow(unused)]
pub(crate) since: Instant,
}

View File

@@ -1,13 +1,149 @@
use super::connection::Step;
use super::Postgres;
use super::PostgresQueryParameters;
use super::PostgresRow;
use crate::backend::Backend;
use crate::describe::{Describe, ResultField};
use crate::query::QueryParameters;
use crate::url::Url;
use async_trait::async_trait;
use futures_core::stream::BoxStream;
#[derive(Debug)]
pub struct Postgres;
#[async_trait]
impl Backend for Postgres {
type QueryParameters = super::PostgresQueryParameters;
type RawConnection = super::PostgresRawConnection;
type Row = super::PostgresRow;
type QueryParameters = PostgresQueryParameters;
type Row = PostgresRow;
type TableIdent = u32;
async fn open(url: &str) -> crate::Result<Self> {
let url = Url::parse(url)?;
let address = url.resolve(5432);
let mut conn = Self::new(address).await?;
conn.startup(
url.username(),
url.password().unwrap_or_default(),
url.database(),
)
.await?;
Ok(conn)
}
#[inline]
async fn close(mut self) -> crate::Result<()> {
self.terminate().await
}
async fn execute(
&mut self,
query: &str,
params: PostgresQueryParameters,
) -> crate::Result<u64> {
self.parse("", query, &params);
self.bind("", "", &params);
self.execute("", 1);
self.sync().await?;
let mut affected = 0;
while let Some(step) = self.step().await? {
if let Step::Command(cnt) = step {
affected = cnt;
}
}
Ok(affected)
}
fn fetch(
&mut self,
query: &str,
params: PostgresQueryParameters,
) -> BoxStream<'_, crate::Result<PostgresRow>> {
self.parse("", query, &params);
self.bind("", "", &params);
self.execute("", 0);
Box::pin(async_stream::try_stream! {
self.sync().await?;
while let Some(step) = self.step().await? {
if let Step::Row(row) = step {
yield row;
}
}
})
}
async fn fetch_optional(
&mut self,
query: &str,
params: PostgresQueryParameters,
) -> crate::Result<Option<PostgresRow>> {
self.parse("", query, &params);
self.bind("", "", &params);
self.execute("", 2);
self.sync().await?;
let mut row: Option<PostgresRow> = None;
while let Some(step) = self.step().await? {
if let Step::Row(r) = step {
if row.is_some() {
return Err(crate::Error::FoundMoreThanOne);
}
row = Some(r);
}
}
Ok(row)
}
async fn describe(&mut self, body: &str) -> crate::Result<Describe<Postgres>> {
self.parse("", body, &PostgresQueryParameters::new());
self.describe("");
self.sync().await?;
let param_desc = loop {
let step = self
.step()
.await?
.ok_or(invalid_data!("did not receive ParameterDescription"));
if let Step::ParamDesc(desc) = step? {
break desc;
}
};
let row_desc = loop {
let step = self
.step()
.await?
.ok_or(invalid_data!("did not receive RowDescription"));
if let Step::RowDesc(desc) = step? {
break desc;
}
};
Ok(Describe {
param_types: param_desc.ids.into_vec(),
result_fields: row_desc
.fields
.into_vec()
.into_iter()
.map(|field| ResultField {
name: Some(field.name),
table_id: Some(field.table_id),
type_id: field.type_id,
})
.collect(),
})
}
}
impl_from_sql_row_tuples_for_backend!(Postgres);

View File

@@ -1,226 +1,307 @@
use super::{Postgres, PostgresQueryParameters, PostgresRawConnection, PostgresRow};
use crate::{
connection::RawConnection,
describe::{Describe, ResultField},
postgres::raw::Step,
query::QueryParameters,
url::Url,
io::{Buf, BufStream},
postgres::{
protocol::{self, Decode, Encode, Message},
PostgresDatabaseError, PostgresQueryParameters, PostgresRow,
},
};
use async_trait::async_trait;
use futures_core::stream::BoxStream;
use async_std::net::TcpStream;
use byteorder::NetworkEndian;
use std::net::Shutdown;
use std::{io, net::SocketAddr};
use std::sync::atomic::{AtomicU64, Ordering};
pub struct Postgres {
stream: BufStream<TcpStream>,
use crate::postgres::{protocol::Message, PostgresDatabaseError};
use std::hash::Hasher;
// Process ID of the Backend
process_id: u32,
#[async_trait]
impl RawConnection for PostgresRawConnection {
type Backend = Postgres;
async fn establish(url: &str) -> crate::Result<Self> {
let url = Url::parse(url)?;
let address = url.resolve(5432);
let mut conn = Self::new(address).await?;
conn.startup(
url.username(),
url.password().unwrap_or_default(),
url.database(),
)
.await?;
Ok(conn)
}
#[inline]
async fn close(mut self) -> crate::Result<()> {
self.terminate().await
}
async fn execute(
&mut self,
query: &str,
params: PostgresQueryParameters,
) -> crate::Result<u64> {
self.parse("", query, &params);
self.bind("", "", &params);
self.execute("", 1);
self.sync().await?;
let mut affected = 0;
while let Some(step) = self.step().await? {
if let Step::Command(cnt) = step {
affected = cnt;
}
}
Ok(affected)
}
fn fetch(
&mut self,
query: &str,
params: PostgresQueryParameters,
) -> BoxStream<'_, crate::Result<PostgresRow>> {
self.parse("", query, &params);
self.bind("", "", &params);
self.execute("", 0);
Box::pin(async_stream::try_stream! {
self.sync().await?;
while let Some(step) = self.step().await? {
if let Step::Row(row) = step {
yield row;
}
}
})
}
async fn fetch_optional(
&mut self,
query: &str,
params: PostgresQueryParameters,
) -> crate::Result<Option<PostgresRow>> {
self.parse("", query, &params);
self.bind("", "", &params);
self.execute("", 2);
self.sync().await?;
let mut row: Option<PostgresRow> = None;
while let Some(step) = self.step().await? {
if let Step::Row(r) = step {
if row.is_some() {
return Err(crate::Error::FoundMoreThanOne);
}
row = Some(r);
}
}
Ok(row)
}
async fn describe(&mut self, body: &str) -> crate::Result<Describe<Postgres>> {
self.parse("", body, &PostgresQueryParameters::new());
self.describe("");
self.sync().await?;
let param_desc = loop {
let step = self
.step()
.await?
.ok_or(invalid_data!("did not receive ParameterDescription"));
if let Step::ParamDesc(desc) = step? {
break desc;
}
};
let row_desc = loop {
let step = self
.step()
.await?
.ok_or(invalid_data!("did not receive RowDescription"));
if let Step::RowDesc(desc) = step? {
break desc;
}
};
Ok(Describe {
param_types: param_desc.ids.into_vec(),
result_fields: row_desc
.fields
.into_vec()
.into_iter()
.map(|field| ResultField {
name: Some(field.name),
table_id: Some(field.table_id),
type_id: field.type_id,
})
.collect(),
})
}
// Backend-unique key to use to send a cancel query message to the server
secret_key: u32,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::query::QueryParameters;
use std::env;
// [x] 52.2.1. Start-up
// [ ] 52.2.2. Simple Query
// [ ] 52.2.3. Extended Query
// [ ] 52.2.4. Function Call
// [ ] 52.2.5. COPY Operations
// [ ] 52.2.6. Asynchronous Operations
// [ ] 52.2.7. Canceling Requests in Progress
// [x] 52.2.8. Termination
// [ ] 52.2.9. SSL Session Encryption
// [ ] 52.2.10. GSSAPI Session Encryption
fn database_url() -> String {
env::var("POSTGRES_DATABASE_URL")
.or_else(|_| env::var("DATABASE_URL"))
.unwrap()
impl Postgres {
pub(super) async fn new(address: SocketAddr) -> crate::Result<Self> {
let stream = TcpStream::connect(&address).await?;
Ok(Self {
stream: BufStream::new(stream),
process_id: 0,
secret_key: 0,
})
}
#[async_std::test]
#[ignore]
async fn it_establishes() -> crate::Result<()> {
let mut conn = PostgresRawConnection::establish(&database_url()).await?;
// https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.3
pub(super) async fn startup(
&mut self,
username: &str,
password: &str,
database: &str,
) -> crate::Result<()> {
// See this doc for more runtime parameters
// https://www.postgresql.org/docs/12/runtime-config-client.html
let params = &[
("user", username),
("database", database),
// Sets the display format for date and time values,
// as well as the rules for interpreting ambiguous date input values.
("DateStyle", "ISO, MDY"),
// Sets the display format for interval values.
("IntervalStyle", "iso_8601"),
// Sets the time zone for displaying and interpreting time stamps.
("TimeZone", "UTC"),
// Adjust postgres to return percise values for floats
// NOTE: This is default in postgres 12+
("extra_float_digits", "3"),
// Sets the client-side encoding (character set).
("client_encoding", "UTF-8"),
];
// After establish, run PING to ensure that it was established correctly
conn.ping().await?;
protocol::StartupMessage { params }.encode(self.stream.buffer_mut());
self.stream.flush().await?;
// Then explicitly close the connection
conn.close().await?;
while let Some(message) = self.receive().await? {
match message {
Message::Authentication(auth) => {
match *auth {
protocol::Authentication::Ok => {
// Do nothing. No password is needed to continue.
}
protocol::Authentication::CleartextPassword => {
protocol::PasswordMessage::Cleartext(password)
.encode(self.stream.buffer_mut());
self.stream.flush().await?;
}
protocol::Authentication::Md5Password { salt } => {
protocol::PasswordMessage::Md5 {
password,
user: username,
salt,
}
.encode(self.stream.buffer_mut());
self.stream.flush().await?;
}
auth => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("requires unimplemented authentication method: {:?}", auth),
)
.into());
}
}
}
Message::BackendKeyData(body) => {
self.process_id = body.process_id();
self.secret_key = body.secret_key();
}
Message::ReadyForQuery(_) => {
// Connection fully established and ready to receive queries.
break;
}
message => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("received unexpected message: {:?}", message),
)
.into());
}
}
}
Ok(())
}
#[async_std::test]
#[ignore]
async fn it_executes() -> crate::Result<()> {
let mut conn = PostgresRawConnection::establish(&database_url()).await?;
// https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.10
pub(super) async fn terminate(mut self) -> crate::Result<()> {
protocol::Terminate.encode(self.stream.buffer_mut());
let affected_rows_from_begin =
RawConnection::execute(&mut conn, "BEGIN", PostgresQueryParameters::new()).await?;
assert_eq!(affected_rows_from_begin, 0);
let affected_rows_from_create_table = RawConnection::execute(
&mut conn,
r#"
CREATE TEMP TABLE sqlx_test_it_executes (
id BIGSERIAL PRIMARY KEY
)
"#,
PostgresQueryParameters::new(),
)
.await?;
assert_eq!(affected_rows_from_create_table, 0);
for _ in 0..5_i32 {
let affected_rows_from_insert = RawConnection::execute(
&mut conn,
"INSERT INTO sqlx_test_it_executes DEFAULT VALUES",
PostgresQueryParameters::new(),
)
.await?;
assert_eq!(affected_rows_from_insert, 1);
}
let affected_rows_from_delete = RawConnection::execute(
&mut conn,
"DELETE FROM sqlx_test_it_executes",
PostgresQueryParameters::new(),
)
.await?;
assert_eq!(affected_rows_from_delete, 5);
let affected_rows_from_rollback =
RawConnection::execute(&mut conn, "ROLLBACK", PostgresQueryParameters::new()).await?;
assert_eq!(affected_rows_from_rollback, 0);
self.stream.flush().await?;
self.stream.stream.shutdown(Shutdown::Both)?;
Ok(())
}
pub(super) fn parse(&mut self, statement: &str, query: &str, params: &PostgresQueryParameters) {
protocol::Parse {
statement,
query,
param_types: &*params.types,
}
.encode(self.stream.buffer_mut());
}
pub(super) fn describe(&mut self, statement: &str) {
protocol::Describe {
kind: protocol::DescribeKind::PreparedStatement,
name: statement,
}
.encode(self.stream.buffer_mut())
}
pub(super) fn bind(&mut self, portal: &str, statement: &str, params: &PostgresQueryParameters) {
protocol::Bind {
portal,
statement,
formats: &[1], // [BINARY]
// TODO: Early error if there is more than i16
values_len: params.types.len() as i16,
values: &*params.buf,
result_formats: &[1], // [BINARY]
}
.encode(self.stream.buffer_mut());
}
pub(super) fn execute(&mut self, portal: &str, limit: i32) {
protocol::Execute { portal, limit }.encode(self.stream.buffer_mut());
}
pub(super) async fn sync(&mut self) -> crate::Result<()> {
protocol::Sync.encode(self.stream.buffer_mut());
self.stream.flush().await?;
Ok(())
}
pub(super) async fn step(&mut self) -> crate::Result<Option<Step>> {
while let Some(message) = self.receive().await? {
match message {
Message::BindComplete
| Message::ParseComplete
| Message::PortalSuspended
| Message::CloseComplete => {}
Message::CommandComplete(body) => {
return Ok(Some(Step::Command(body.affected_rows())));
}
Message::DataRow(body) => {
return Ok(Some(Step::Row(PostgresRow(body))));
}
Message::ReadyForQuery(_) => {
return Ok(None);
}
Message::ParameterDescription(desc) => {
return Ok(Some(Step::ParamDesc(desc)));
}
Message::RowDescription(desc) => {
return Ok(Some(Step::RowDesc(desc)));
}
message => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("received unexpected message: {:?}", message),
)
.into());
}
}
}
// Connection was (unexpectedly) closed
Err(io::Error::from(io::ErrorKind::UnexpectedEof).into())
}
// Wait and return the next message to be received from Postgres.
pub(super) async fn receive(&mut self) -> crate::Result<Option<Message>> {
// Before we start the receive loop
// Flush any pending data from the send buffer
self.stream.flush().await?;
loop {
// Read the message header (id + len)
let mut header = ret_if_none!(self.stream.peek(5).await?);
let id = header.get_u8()?;
let len = (header.get_u32::<NetworkEndian>()? - 4) as usize;
// Read the message body
self.stream.consume(5);
let body = ret_if_none!(self.stream.peek(len).await?);
let message = match id {
b'N' | b'E' => Message::Response(Box::new(protocol::Response::decode(body)?)),
b'D' => Message::DataRow(protocol::DataRow::decode(body)?),
b'S' => {
Message::ParameterStatus(Box::new(protocol::ParameterStatus::decode(body)?))
}
b'Z' => Message::ReadyForQuery(protocol::ReadyForQuery::decode(body)?),
b'R' => Message::Authentication(Box::new(protocol::Authentication::decode(body)?)),
b'K' => Message::BackendKeyData(protocol::BackendKeyData::decode(body)?),
b'C' => Message::CommandComplete(protocol::CommandComplete::decode(body)?),
b'A' => Message::NotificationResponse(Box::new(
protocol::NotificationResponse::decode(body)?,
)),
b'1' => Message::ParseComplete,
b'2' => Message::BindComplete,
b'3' => Message::CloseComplete,
b'n' => Message::NoData,
b's' => Message::PortalSuspended,
b't' => Message::ParameterDescription(Box::new(
protocol::ParameterDescription::decode(body)?,
)),
b'T' => Message::RowDescription(Box::new(protocol::RowDescription::decode(body)?)),
id => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("received unknown message id: {:?}", id),
)
.into());
}
};
self.stream.consume(len);
match message {
Message::ParameterStatus(_body) => {
// TODO: not sure what to do with these yet
}
Message::Response(body) => {
if body.severity().is_error() {
// This is an error, stop the world and bubble as an error
return Err(PostgresDatabaseError(body).into());
} else {
// This is a _warning_
// TODO: Do we *want* to do anything with these
}
}
message => {
return Ok(Some(message));
}
}
}
}
}
#[derive(Debug)]
pub(super) enum Step {
Command(u64),
Row(PostgresRow),
ParamDesc(Box<protocol::ParameterDescription>),
RowDesc(Box<protocol::RowDescription>),
}

View File

@@ -2,7 +2,6 @@ mod backend;
mod connection;
mod error;
mod query;
mod raw;
mod row;
#[cfg(not(feature = "unstable"))]
@@ -14,6 +13,6 @@ pub mod protocol;
pub mod types;
pub use self::{
backend::Postgres, error::PostgresDatabaseError, query::PostgresQueryParameters,
raw::PostgresRawConnection, row::PostgresRow,
connection::Postgres, error::PostgresDatabaseError, query::PostgresQueryParameters,
row::PostgresRow,
};

View File

@@ -1,307 +0,0 @@
use crate::{
io::{Buf, BufStream},
postgres::{
protocol::{self, Decode, Encode, Message},
PostgresDatabaseError, PostgresQueryParameters, PostgresRow,
},
};
use std::net::Shutdown;
use byteorder::NetworkEndian;
use std::{io, net::SocketAddr};
use async_std::net::TcpStream;
pub struct PostgresRawConnection {
stream: BufStream<TcpStream>,
// Process ID of the Backend
process_id: u32,
// Backend-unique key to use to send a cancel query message to the server
secret_key: u32,
}
// [x] 52.2.1. Start-up
// [ ] 52.2.2. Simple Query
// [ ] 52.2.3. Extended Query
// [ ] 52.2.4. Function Call
// [ ] 52.2.5. COPY Operations
// [ ] 52.2.6. Asynchronous Operations
// [ ] 52.2.7. Canceling Requests in Progress
// [x] 52.2.8. Termination
// [ ] 52.2.9. SSL Session Encryption
// [ ] 52.2.10. GSSAPI Session Encryption
impl PostgresRawConnection {
pub(super) async fn new(address: SocketAddr) -> crate::Result<Self> {
let stream = TcpStream::connect(&address).await?;
Ok(Self {
stream: BufStream::new(stream),
process_id: 0,
secret_key: 0,
})
}
// https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.3
pub(super) async fn startup(
&mut self,
username: &str,
password: &str,
database: &str,
) -> crate::Result<()> {
// See this doc for more runtime parameters
// https://www.postgresql.org/docs/12/runtime-config-client.html
let params = &[
("user", username),
("database", database),
// Sets the display format for date and time values,
// as well as the rules for interpreting ambiguous date input values.
("DateStyle", "ISO, MDY"),
// Sets the display format for interval values.
("IntervalStyle", "iso_8601"),
// Sets the time zone for displaying and interpreting time stamps.
("TimeZone", "UTC"),
// Adjust postgres to return percise values for floats
// NOTE: This is default in postgres 12+
("extra_float_digits", "3"),
// Sets the client-side encoding (character set).
("client_encoding", "UTF-8"),
];
protocol::StartupMessage { params }.encode(self.stream.buffer_mut());
self.stream.flush().await?;
while let Some(message) = self.receive().await? {
match message {
Message::Authentication(auth) => {
match *auth {
protocol::Authentication::Ok => {
// Do nothing. No password is needed to continue.
}
protocol::Authentication::CleartextPassword => {
protocol::PasswordMessage::Cleartext(password)
.encode(self.stream.buffer_mut());
self.stream.flush().await?;
}
protocol::Authentication::Md5Password { salt } => {
protocol::PasswordMessage::Md5 {
password,
user: username,
salt,
}
.encode(self.stream.buffer_mut());
self.stream.flush().await?;
}
auth => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("requires unimplemented authentication method: {:?}", auth),
)
.into());
}
}
}
Message::BackendKeyData(body) => {
self.process_id = body.process_id();
self.secret_key = body.secret_key();
}
Message::ReadyForQuery(_) => {
// Connection fully established and ready to receive queries.
break;
}
message => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("received unexpected message: {:?}", message),
)
.into());
}
}
}
Ok(())
}
// https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.10
pub(super) async fn terminate(mut self) -> crate::Result<()> {
protocol::Terminate.encode(self.stream.buffer_mut());
self.stream.flush().await?;
self.stream.stream.shutdown(Shutdown::Both)?;
Ok(())
}
pub(super) fn parse(&mut self, statement: &str, query: &str, params: &PostgresQueryParameters) {
protocol::Parse {
statement,
query,
param_types: &*params.types,
}
.encode(self.stream.buffer_mut());
}
pub(super) fn describe(&mut self, statement: &str) {
protocol::Describe {
kind: protocol::DescribeKind::PreparedStatement,
name: statement,
}
.encode(self.stream.buffer_mut())
}
pub(super) fn bind(&mut self, portal: &str, statement: &str, params: &PostgresQueryParameters) {
protocol::Bind {
portal,
statement,
formats: &[1], // [BINARY]
// TODO: Early error if there is more than i16
values_len: params.types.len() as i16,
values: &*params.buf,
result_formats: &[1], // [BINARY]
}
.encode(self.stream.buffer_mut());
}
pub(super) fn execute(&mut self, portal: &str, limit: i32) {
protocol::Execute { portal, limit }.encode(self.stream.buffer_mut());
}
pub(super) async fn sync(&mut self) -> crate::Result<()> {
protocol::Sync.encode(self.stream.buffer_mut());
self.stream.flush().await?;
Ok(())
}
pub(super) async fn step(&mut self) -> crate::Result<Option<Step>> {
while let Some(message) = self.receive().await? {
match message {
Message::BindComplete
| Message::ParseComplete
| Message::PortalSuspended
| Message::CloseComplete => {}
Message::CommandComplete(body) => {
return Ok(Some(Step::Command(body.affected_rows())));
}
Message::DataRow(body) => {
return Ok(Some(Step::Row(PostgresRow(body))));
}
Message::ReadyForQuery(_) => {
return Ok(None);
}
Message::ParameterDescription(desc) => {
return Ok(Some(Step::ParamDesc(desc)));
}
Message::RowDescription(desc) => {
return Ok(Some(Step::RowDesc(desc)));
}
message => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("received unexpected message: {:?}", message),
)
.into());
}
}
}
// Connection was (unexpectedly) closed
Err(io::Error::from(io::ErrorKind::UnexpectedEof).into())
}
// Wait and return the next message to be received from Postgres.
pub(super) async fn receive(&mut self) -> crate::Result<Option<Message>> {
// Before we start the receive loop
// Flush any pending data from the send buffer
self.stream.flush().await?;
loop {
// Read the message header (id + len)
let mut header = ret_if_none!(self.stream.peek(5).await?);
let id = header.get_u8()?;
let len = (header.get_u32::<NetworkEndian>()? - 4) as usize;
// Read the message body
self.stream.consume(5);
let body = ret_if_none!(self.stream.peek(len).await?);
let message = match id {
b'N' | b'E' => Message::Response(Box::new(protocol::Response::decode(body)?)),
b'D' => Message::DataRow(protocol::DataRow::decode(body)?),
b'S' => {
Message::ParameterStatus(Box::new(protocol::ParameterStatus::decode(body)?))
}
b'Z' => Message::ReadyForQuery(protocol::ReadyForQuery::decode(body)?),
b'R' => Message::Authentication(Box::new(protocol::Authentication::decode(body)?)),
b'K' => Message::BackendKeyData(protocol::BackendKeyData::decode(body)?),
b'C' => Message::CommandComplete(protocol::CommandComplete::decode(body)?),
b'A' => Message::NotificationResponse(Box::new(
protocol::NotificationResponse::decode(body)?,
)),
b'1' => Message::ParseComplete,
b'2' => Message::BindComplete,
b'3' => Message::CloseComplete,
b'n' => Message::NoData,
b's' => Message::PortalSuspended,
b't' => Message::ParameterDescription(Box::new(
protocol::ParameterDescription::decode(body)?,
)),
b'T' => Message::RowDescription(Box::new(protocol::RowDescription::decode(body)?)),
id => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("received unknown message id: {:?}", id),
)
.into());
}
};
self.stream.consume(len);
match message {
Message::ParameterStatus(_body) => {
// TODO: not sure what to do with these yet
}
Message::Response(body) => {
if body.severity().is_error() {
// This is an error, stop the world and bubble as an error
return Err(PostgresDatabaseError(body).into());
} else {
// This is a _warning_
// TODO: Do we *want* to do anything with these
}
}
message => {
return Ok(Some(message));
}
}
}
}
}
#[derive(Debug)]
pub(super) enum Step {
Command(u64),
Row(PostgresRow),
ParamDesc(Box<protocol::ParameterDescription>),
RowDesc(Box<protocol::RowDescription>),
}

View File

@@ -88,22 +88,22 @@ async fn process_sql(input: MacroInput) -> Result<TokenStream> {
match db_url.scheme() {
#[cfg(feature = "postgres")]
"postgresql" => {
"postgresql" | "postgres" => {
process_sql_with(
input,
sqlx::Connection::<sqlx::Postgres>::establish(db_url.as_str())
sqlx::Connection::<sqlx::Postgres>::open(db_url.as_str())
.await
.map_err(|e| format!("failed to connect to database: {}", e))?,
)
.await
}
#[cfg(feature = "mariadb")]
"mysql" => {
"mysql" | "mariadb" => {
process_sql_with(
input,
sqlx::Connection::<sqlx::MariaDb>::establish(db_url.as_str())
.await
.map_err(|e| format!("failed to connect to database: {}", e))?,
sqlx::Connection::<sqlx::MariaDb>::open(db_url.as_str())
.await
.map_err(|e| format!("failed to connect to database: {}", e))?,
)
.await
}

View File

@@ -1,7 +1,7 @@
#[async_std::test]
async fn test_sqlx_macro() -> sqlx::Result<()> {
let mut conn =
sqlx::Connection::<sqlx::Postgres>::establish("postgres://postgres@127.0.0.1/sqlx_test")
sqlx::Connection::<sqlx::Postgres>::open("postgres://postgres@127.0.0.1/sqlx_test")
.await?;
let uuid: sqlx::types::Uuid = "256ba9c8-0048-11ea-b0f0-8f04859d047e".parse().unwrap();
let accounts = sqlx::query!("SELECT * from accounts where id != $1", None)