mirror of
https://github.com/launchbadge/sqlx.git
synced 2026-04-30 05:54:20 +00:00
Remove the RawConnection concept and fold into Backend
This commit is contained in:
@@ -1,20 +1,61 @@
|
||||
use crate::{connection::RawConnection, query::QueryParameters, row::Row, types::HasTypeMetadata};
|
||||
use crate::describe::Describe;
|
||||
use crate::{query::QueryParameters, row::Row, types::HasTypeMetadata};
|
||||
use async_trait::async_trait;
|
||||
use futures_core::stream::BoxStream;
|
||||
|
||||
/// A database backend.
|
||||
///
|
||||
/// This trait represents the concept of a backend (e.g. "MySQL" vs "SQLite").
|
||||
pub trait Backend: HasTypeMetadata + Sized {
|
||||
/// Represents a connection to the database and further provides auxillary but
|
||||
/// important related traits as associated types.
|
||||
///
|
||||
/// This trait is not intended to be used directly.
|
||||
/// Instead [sqlx::Connection] or [sqlx::Pool] should be used instead,
|
||||
/// which provide concurrent access and typed retrieval of results.
|
||||
#[async_trait]
|
||||
pub trait Backend: HasTypeMetadata + Send + Sync + Sized {
|
||||
/// The concrete `QueryParameters` implementation for this backend.
|
||||
type QueryParameters: QueryParameters<Backend = Self>;
|
||||
|
||||
/// The concrete `RawConnection` implementation for this backend.
|
||||
type RawConnection: RawConnection<Backend = Self>;
|
||||
|
||||
/// The concrete `Row` implementation for this backend. This type is returned
|
||||
/// from methods in the `RawConnection`.
|
||||
/// The concrete `Row` implementation for this backend.
|
||||
type Row: Row<Backend = Self>;
|
||||
|
||||
/// The identifier for tables; in Postgres this is an `oid` while
|
||||
/// in MariaDB/MySQL this is the qualified name of the table.
|
||||
type TableIdent;
|
||||
|
||||
/// Establish a new connection to the database server.
|
||||
async fn open(url: &str) -> crate::Result<Self>
|
||||
where
|
||||
Self: Sized;
|
||||
|
||||
/// Release resources for this database connection immediately.
|
||||
///
|
||||
/// This method is not required to be called. A database server will
|
||||
/// eventually notice and clean up not fully closed connections.
|
||||
async fn close(mut self) -> crate::Result<()>;
|
||||
|
||||
async fn ping(&mut self) -> crate::Result<()> {
|
||||
// TODO: Does this need to be specialized for any database backends?
|
||||
let _ = self
|
||||
.execute("SELECT 1", Self::QueryParameters::new())
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn describe(&mut self, query: &str) -> crate::Result<Describe<Self>>;
|
||||
|
||||
async fn execute(&mut self, query: &str, params: Self::QueryParameters) -> crate::Result<u64>;
|
||||
|
||||
fn fetch(
|
||||
&mut self,
|
||||
query: &str,
|
||||
params: Self::QueryParameters,
|
||||
) -> BoxStream<'_, crate::Result<Self::Row>>;
|
||||
|
||||
async fn fetch_optional(
|
||||
&mut self,
|
||||
query: &str,
|
||||
params: Self::QueryParameters,
|
||||
) -> crate::Result<Option<Self::Row>>;
|
||||
}
|
||||
|
||||
@@ -4,77 +4,12 @@ use crate::{
|
||||
error::Error,
|
||||
executor::Executor,
|
||||
pool::{Live, SharedPool},
|
||||
query::{IntoQueryParameters, QueryParameters},
|
||||
query::IntoQueryParameters,
|
||||
row::FromSqlRow,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use crossbeam_queue::SegQueue;
|
||||
use crossbeam_utils::atomic::AtomicCell;
|
||||
use futures_channel::oneshot::{channel, Sender};
|
||||
use futures_core::{future::BoxFuture, stream::BoxStream};
|
||||
use futures_util::stream::StreamExt;
|
||||
use std::{
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
},
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
/// A connection to the database.
|
||||
///
|
||||
/// This trait is not intended to be used directly. Instead [sqlx::Connection] or [sqlx::Pool] should be used instead, which provide
|
||||
/// concurrent access and typed retrieval of results.
|
||||
#[async_trait]
|
||||
pub trait RawConnection: Send + Sync {
|
||||
// The database backend this type connects to.
|
||||
type Backend: Backend;
|
||||
|
||||
/// Establish a new connection to the database server.
|
||||
async fn establish(url: &str) -> crate::Result<Self>
|
||||
where
|
||||
Self: Sized;
|
||||
|
||||
/// Release resources for this database connection immediately.
|
||||
///
|
||||
/// This method is not required to be called. A database server will eventually notice
|
||||
/// and clean up not fully closed connections.
|
||||
///
|
||||
/// It is safe to close an already closed connection.
|
||||
async fn close(mut self) -> crate::Result<()>;
|
||||
|
||||
/// Verifies a connection to the database is still alive.
|
||||
async fn ping(&mut self) -> crate::Result<()> {
|
||||
let _ = self
|
||||
.execute(
|
||||
"SELECT 1",
|
||||
<<Self::Backend as Backend>::QueryParameters>::new(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&mut self,
|
||||
query: &str,
|
||||
params: <Self::Backend as Backend>::QueryParameters,
|
||||
) -> crate::Result<u64>;
|
||||
|
||||
fn fetch(
|
||||
&mut self,
|
||||
query: &str,
|
||||
params: <Self::Backend as Backend>::QueryParameters,
|
||||
) -> BoxStream<'_, crate::Result<<Self::Backend as Backend>::Row>>;
|
||||
|
||||
async fn fetch_optional(
|
||||
&mut self,
|
||||
query: &str,
|
||||
params: <Self::Backend as Backend>::QueryParameters,
|
||||
) -> crate::Result<Option<<Self::Backend as Backend>::Row>>;
|
||||
|
||||
async fn describe(&mut self, query: &str) -> crate::Result<Describe<Self::Backend>>;
|
||||
}
|
||||
use std::{sync::Arc, time::Instant};
|
||||
|
||||
pub struct Connection<DB>
|
||||
where
|
||||
@@ -95,8 +30,8 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn establish(url: &str) -> crate::Result<Self> {
|
||||
let raw = <DB as Backend>::RawConnection::establish(url).await?;
|
||||
pub async fn open(url: &str) -> crate::Result<Self> {
|
||||
let raw = DB::open(url).await?;
|
||||
let live = Live {
|
||||
raw,
|
||||
since: Instant::now(),
|
||||
|
||||
@@ -13,7 +13,9 @@ pub struct Describe<DB: Backend> {
|
||||
}
|
||||
|
||||
impl<DB: Backend> fmt::Debug for Describe<DB>
|
||||
where <DB as HasTypeMetadata>::TypeId: fmt::Debug, ResultField<DB>: fmt::Debug
|
||||
where
|
||||
<DB as HasTypeMetadata>::TypeId: fmt::Debug,
|
||||
ResultField<DB>: fmt::Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
f.debug_struct("Describe")
|
||||
@@ -31,7 +33,9 @@ pub struct ResultField<DB: Backend> {
|
||||
}
|
||||
|
||||
impl<DB: Backend> fmt::Debug for ResultField<DB>
|
||||
where <DB as Backend>::TableIdent: fmt::Debug, <DB as HasTypeMetadata>::TypeId: fmt::Debug
|
||||
where
|
||||
<DB as Backend>::TableIdent: fmt::Debug,
|
||||
<DB as HasTypeMetadata>::TypeId: fmt::Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
f.debug_struct("ResultField")
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use async_std::io::{
|
||||
prelude::{ReadExt, WriteExt},
|
||||
Read, Write,
|
||||
};
|
||||
use bytes::{BufMut, BytesMut};
|
||||
use std::io;
|
||||
use async_std::io::{Read, Write, prelude::{ReadExt, WriteExt}};
|
||||
use async_std::future::poll_fn;
|
||||
|
||||
pub struct BufStream<S> {
|
||||
pub(crate) stream: S,
|
||||
|
||||
@@ -1,13 +1,87 @@
|
||||
use super::{MariaDb, MariaDbQueryParameters, MariaDbRow};
|
||||
use crate::backend::Backend;
|
||||
use crate::describe::{Describe, ResultField};
|
||||
use crate::mariadb::protocol::ColumnDefinitionPacket;
|
||||
use async_trait::async_trait;
|
||||
use futures_core::stream::BoxStream;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MariaDb;
|
||||
|
||||
#[async_trait]
|
||||
impl Backend for MariaDb {
|
||||
type QueryParameters = super::MariaDbQueryParameters;
|
||||
type RawConnection = super::MariaDbRawConnection;
|
||||
type Row = super::MariaDbRow;
|
||||
type QueryParameters = MariaDbQueryParameters;
|
||||
type Row = MariaDbRow;
|
||||
type TableIdent = String;
|
||||
|
||||
async fn open(url: &str) -> crate::Result<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
MariaDb::open(url).await
|
||||
}
|
||||
|
||||
async fn close(mut self) -> crate::Result<()> {
|
||||
self.close().await
|
||||
}
|
||||
|
||||
async fn ping(&mut self) -> crate::Result<()> {
|
||||
self.ping().await
|
||||
}
|
||||
|
||||
async fn execute(&mut self, query: &str, params: MariaDbQueryParameters) -> crate::Result<u64> {
|
||||
// Write prepare statement to buffer
|
||||
self.start_sequence();
|
||||
let prepare_ok = self.send_prepare(query).await?;
|
||||
|
||||
let affected = self.execute(prepare_ok.statement_id, params).await?;
|
||||
|
||||
Ok(affected)
|
||||
}
|
||||
|
||||
fn fetch(
|
||||
&mut self,
|
||||
_query: &str,
|
||||
_params: MariaDbQueryParameters,
|
||||
) -> BoxStream<'_, crate::Result<MariaDbRow>> {
|
||||
unimplemented!();
|
||||
}
|
||||
|
||||
async fn fetch_optional(
|
||||
&mut self,
|
||||
_query: &str,
|
||||
_params: MariaDbQueryParameters,
|
||||
) -> crate::Result<Option<Self::Row>> {
|
||||
unimplemented!();
|
||||
}
|
||||
|
||||
async fn describe(&mut self, query: &str) -> crate::Result<Describe<MariaDb>> {
|
||||
let prepare_ok = self.send_prepare(query).await?;
|
||||
|
||||
let mut param_types = Vec::with_capacity(prepare_ok.params as usize);
|
||||
|
||||
for _ in 0..prepare_ok.params {
|
||||
let param = ColumnDefinitionPacket::decode(self.receive().await?)?;
|
||||
param_types.push(param.field_type.0);
|
||||
}
|
||||
|
||||
self.check_eof().await?;
|
||||
|
||||
let mut columns = Vec::with_capacity(prepare_ok.columns as usize);
|
||||
|
||||
for _ in 0..prepare_ok.columns {
|
||||
let column = ColumnDefinitionPacket::decode(self.receive().await?)?;
|
||||
columns.push(ResultField {
|
||||
name: column.column_alias.or(column.column),
|
||||
table_id: column.table_alias.or(column.table),
|
||||
type_id: column.field_type.0,
|
||||
})
|
||||
}
|
||||
|
||||
self.check_eof().await?;
|
||||
|
||||
Ok(Describe {
|
||||
param_types,
|
||||
result_fields: columns,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl_from_sql_row_tuples_for_backend!(MariaDb);
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
use super::establish;
|
||||
use crate::{
|
||||
connection::RawConnection,
|
||||
describe::{Describe, ResultField},
|
||||
error::DatabaseError,
|
||||
io::{Buf, BufMut, BufStream},
|
||||
mariadb::{
|
||||
protocol::{
|
||||
@@ -10,31 +7,27 @@ use crate::{
|
||||
ComStmtExecute, ComStmtPrepare, ComStmtPrepareOk, Encode, EofPacket, ErrPacket,
|
||||
OkPacket, ResultRow, StmtExecFlag,
|
||||
},
|
||||
MariaDb, MariaDbQueryParameters, MariaDbRow,
|
||||
MariaDbQueryParameters,
|
||||
},
|
||||
Backend, Error, Result,
|
||||
Error, Result,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use async_std::net::TcpStream;
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
use futures_core::{future::BoxFuture, stream::BoxStream};
|
||||
use futures_util::stream::{self, StreamExt};
|
||||
use std::{
|
||||
future::Future,
|
||||
io,
|
||||
net::{IpAddr, SocketAddr},
|
||||
};
|
||||
use async_std::net::TcpStream;
|
||||
use url::Url;
|
||||
|
||||
pub struct MariaDbRawConnection {
|
||||
pub struct MariaDb {
|
||||
pub(crate) stream: BufStream<TcpStream>,
|
||||
pub(crate) rbuf: Vec<u8>,
|
||||
pub(crate) capabilities: Capabilities,
|
||||
next_seq_no: u8,
|
||||
}
|
||||
|
||||
impl MariaDbRawConnection {
|
||||
async fn establish(url: &str) -> Result<Self> {
|
||||
impl MariaDb {
|
||||
pub async fn open(url: &str) -> Result<Self> {
|
||||
// TODO: Handle errors
|
||||
let url = Url::parse(url).unwrap();
|
||||
|
||||
@@ -110,7 +103,7 @@ impl MariaDbRawConnection {
|
||||
Ok(Some(&self.rbuf[..len]))
|
||||
}
|
||||
|
||||
fn start_sequence(&mut self) {
|
||||
pub(super) fn start_sequence(&mut self) {
|
||||
// At the start of a command sequence we reset our understanding
|
||||
// of [next_seq_no]. In a sequence our initial command must be 0, followed
|
||||
// by the server response that is 1, then our response to that response (if any),
|
||||
@@ -147,7 +140,7 @@ impl MariaDbRawConnection {
|
||||
// to terminate immediately
|
||||
pub(crate) async fn receive_ok_or_err(&mut self) -> Result<OkPacket> {
|
||||
let capabilities = self.capabilities;
|
||||
let mut buf = self.receive().await?;
|
||||
let buf = self.receive().await?;
|
||||
Ok(match buf[0] {
|
||||
0xfe | 0x00 => OkPacket::decode(buf, capabilities)?,
|
||||
|
||||
@@ -175,7 +168,7 @@ impl MariaDbRawConnection {
|
||||
})
|
||||
}
|
||||
|
||||
async fn check_eof(&mut self) -> Result<()> {
|
||||
pub(super) async fn check_eof(&mut self) -> Result<()> {
|
||||
if !self
|
||||
.capabilities
|
||||
.contains(Capabilities::CLIENT_DEPRECATE_EOF)
|
||||
@@ -186,7 +179,10 @@ impl MariaDbRawConnection {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_prepare<'c>(&'c mut self, statement: &'c str) -> Result<ComStmtPrepareOk> {
|
||||
pub(super) async fn send_prepare<'c>(
|
||||
&'c mut self,
|
||||
statement: &'c str,
|
||||
) -> Result<ComStmtPrepareOk> {
|
||||
self.stream.flush().await?;
|
||||
|
||||
self.start_sequence();
|
||||
@@ -204,7 +200,11 @@ impl MariaDbRawConnection {
|
||||
ComStmtPrepareOk::decode(packet).map_err(Into::into)
|
||||
}
|
||||
|
||||
async fn execute(&mut self, statement_id: u32, params: MariaDbQueryParameters) -> Result<u64> {
|
||||
pub(super) async fn execute(
|
||||
&mut self,
|
||||
statement_id: u32,
|
||||
_params: MariaDbQueryParameters,
|
||||
) -> Result<u64> {
|
||||
// TODO: EXECUTE(READ_ONLY) => FETCH instead of EXECUTE(NO)
|
||||
|
||||
// SEND ================
|
||||
@@ -279,130 +279,3 @@ impl MariaDbRawConnection {
|
||||
Ok(rows)
|
||||
}
|
||||
}
|
||||
|
||||
enum ExecResult {
|
||||
NoRows(OkPacket),
|
||||
Rows(Vec<ColumnDefinitionPacket>),
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl RawConnection for MariaDbRawConnection {
|
||||
type Backend = MariaDb;
|
||||
|
||||
async fn establish(url: &str) -> crate::Result<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
MariaDbRawConnection::establish(url).await
|
||||
}
|
||||
|
||||
async fn close(mut self) -> crate::Result<()> {
|
||||
self.close().await
|
||||
}
|
||||
|
||||
async fn ping(&mut self) -> crate::Result<()> {
|
||||
self.ping().await
|
||||
}
|
||||
|
||||
async fn execute(&mut self, query: &str, params: MariaDbQueryParameters) -> crate::Result<u64> {
|
||||
// Write prepare statement to buffer
|
||||
self.start_sequence();
|
||||
let prepare_ok = self.send_prepare(query).await?;
|
||||
|
||||
let affected = self.execute(prepare_ok.statement_id, params).await?;
|
||||
|
||||
Ok(affected)
|
||||
}
|
||||
|
||||
fn fetch(
|
||||
&mut self,
|
||||
query: &str,
|
||||
params: MariaDbQueryParameters,
|
||||
) -> BoxStream<'_, Result<MariaDbRow>> {
|
||||
unimplemented!();
|
||||
}
|
||||
|
||||
async fn fetch_optional(
|
||||
&mut self,
|
||||
query: &str,
|
||||
params: MariaDbQueryParameters,
|
||||
) -> crate::Result<Option<<Self::Backend as Backend>::Row>> {
|
||||
unimplemented!();
|
||||
}
|
||||
|
||||
async fn describe(&mut self, query: &str) -> crate::Result<Describe<MariaDb>> {
|
||||
let prepare_ok = self.send_prepare(query).await?;
|
||||
|
||||
let mut param_types = Vec::with_capacity(prepare_ok.params as usize);
|
||||
|
||||
for _ in 0..prepare_ok.params {
|
||||
let param = ColumnDefinitionPacket::decode(self.receive().await?)?;
|
||||
param_types.push(param.field_type.0);
|
||||
}
|
||||
|
||||
self.check_eof().await?;
|
||||
|
||||
let mut columns = Vec::with_capacity(prepare_ok.columns as usize);
|
||||
|
||||
for _ in 0..prepare_ok.columns {
|
||||
let column = ColumnDefinitionPacket::decode(self.receive().await?)?;
|
||||
columns.push(ResultField {
|
||||
name: column.column_alias.or(column.column),
|
||||
table_id: column.table_alias.or(column.table),
|
||||
type_id: column.field_type.0,
|
||||
})
|
||||
}
|
||||
|
||||
self.check_eof().await?;
|
||||
|
||||
Ok(Describe {
|
||||
param_types,
|
||||
result_fields: columns,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use crate::{query::QueryParameters, Error, Pool};
|
||||
|
||||
#[async_std::test]
|
||||
async fn it_can_connect() -> Result<()> {
|
||||
MariaDbRawConnection::establish("mariadb://root@127.0.0.1:3306/test").await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[async_std::test]
|
||||
async fn it_fails_to_connect_with_bad_username() -> Result<()> {
|
||||
match MariaDbRawConnection::establish("mariadb://roote@127.0.0.1:3306/test").await {
|
||||
Ok(_) => panic!("Somehow connected to database with incorrect username"),
|
||||
Err(_) => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
#[async_std::test]
|
||||
async fn it_can_ping() -> Result<()> {
|
||||
let mut conn =
|
||||
MariaDbRawConnection::establish("mariadb://root@127.0.0.1:3306/test").await?;
|
||||
conn.ping().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[async_std::test]
|
||||
async fn it_can_describe() -> Result<()> {
|
||||
let mut conn =
|
||||
MariaDbRawConnection::establish("mysql://sqlx_user@127.0.0.1:3306/sqlx_test").await?;
|
||||
let describe = conn.describe("SELECT id from accounts where id = ?").await?;
|
||||
|
||||
dbg!(describe);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[async_std::test]
|
||||
async fn it_can_create_mariadb_pool() -> Result<()> {
|
||||
let pool: Pool<MariaDb> = Pool::new("mariadb://root@127.0.0.1:3306/test").await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
use crate::{
|
||||
mariadb::{
|
||||
connection::MariaDbRawConnection,
|
||||
protocol::{Capabilities, Encode, HandshakeResponsePacket, InitialHandshakePacket},
|
||||
connection::MariaDb,
|
||||
protocol::{Capabilities, HandshakeResponsePacket, InitialHandshakePacket},
|
||||
},
|
||||
Result,
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
pub(crate) async fn establish(conn: &mut MariaDbRawConnection, url: &Url) -> Result<()> {
|
||||
pub(crate) async fn establish(conn: &mut MariaDb, url: &Url) -> Result<()> {
|
||||
let initial = InitialHandshakePacket::decode(conn.receive().await?)?;
|
||||
|
||||
// TODO: Capabilities::SECURE_CONNECTION
|
||||
@@ -17,7 +17,7 @@ pub(crate) async fn establish(conn: &mut MariaDbRawConnection, url: &Url) -> Res
|
||||
// TODO: Capabilities::TRANSACTIONS
|
||||
// TODO: Capabilities::CLIENT_DEPRECATE_EOF
|
||||
// TODO?: Capabilities::CLIENT_SESSION_TRACK
|
||||
let mut capabilities = Capabilities::CLIENT_PROTOCOL_41 | Capabilities::CONNECT_WITH_DB;
|
||||
let capabilities = Capabilities::CLIENT_PROTOCOL_41 | Capabilities::CONNECT_WITH_DB;
|
||||
|
||||
let response = HandshakeResponsePacket {
|
||||
// TODO: Find a good value for [max_packet_size]
|
||||
|
||||
@@ -168,6 +168,7 @@ mod tests {
|
||||
assert_eq!(&buf[..], b"\xFC\xFE\x00");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_ff() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian, _>(0xFF as u64);
|
||||
|
||||
@@ -8,7 +8,4 @@ mod query;
|
||||
mod row;
|
||||
pub mod types;
|
||||
|
||||
pub use self::{
|
||||
backend::MariaDb, connection::MariaDbRawConnection, query::MariaDbQueryParameters,
|
||||
row::MariaDbRow,
|
||||
};
|
||||
pub use self::{connection::MariaDb, query::MariaDbQueryParameters, row::MariaDbRow};
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
use crate::{
|
||||
backend::Backend,
|
||||
connection::{Connection, RawConnection},
|
||||
error::Error,
|
||||
executor::Executor,
|
||||
query::IntoQueryParameters,
|
||||
row::FromSqlRow,
|
||||
backend::Backend, connection::Connection, error::Error, executor::Executor,
|
||||
query::IntoQueryParameters, row::FromSqlRow,
|
||||
};
|
||||
use crossbeam_queue::{ArrayQueue, SegQueue};
|
||||
use futures_channel::oneshot;
|
||||
@@ -238,7 +234,7 @@ where
|
||||
if self.size.compare_and_swap(size, size + 1, Ordering::AcqRel) == size {
|
||||
// Open a new connection and return directly
|
||||
|
||||
let raw = <DB as Backend>::RawConnection::establish(&self.url).await?;
|
||||
let raw = DB::open(&self.url).await?;
|
||||
let live = Live {
|
||||
raw,
|
||||
since: Instant::now(),
|
||||
@@ -404,7 +400,7 @@ pub(crate) struct Live<DB>
|
||||
where
|
||||
DB: Backend,
|
||||
{
|
||||
pub(crate) raw: DB::RawConnection,
|
||||
pub(crate) raw: DB,
|
||||
#[allow(unused)]
|
||||
pub(crate) since: Instant,
|
||||
}
|
||||
|
||||
@@ -1,13 +1,149 @@
|
||||
use super::connection::Step;
|
||||
use super::Postgres;
|
||||
use super::PostgresQueryParameters;
|
||||
use super::PostgresRow;
|
||||
use crate::backend::Backend;
|
||||
use crate::describe::{Describe, ResultField};
|
||||
use crate::query::QueryParameters;
|
||||
use crate::url::Url;
|
||||
use async_trait::async_trait;
|
||||
use futures_core::stream::BoxStream;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Postgres;
|
||||
|
||||
#[async_trait]
|
||||
impl Backend for Postgres {
|
||||
type QueryParameters = super::PostgresQueryParameters;
|
||||
type RawConnection = super::PostgresRawConnection;
|
||||
type Row = super::PostgresRow;
|
||||
type QueryParameters = PostgresQueryParameters;
|
||||
|
||||
type Row = PostgresRow;
|
||||
|
||||
type TableIdent = u32;
|
||||
|
||||
async fn open(url: &str) -> crate::Result<Self> {
|
||||
let url = Url::parse(url)?;
|
||||
let address = url.resolve(5432);
|
||||
let mut conn = Self::new(address).await?;
|
||||
|
||||
conn.startup(
|
||||
url.username(),
|
||||
url.password().unwrap_or_default(),
|
||||
url.database(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
async fn close(mut self) -> crate::Result<()> {
|
||||
self.terminate().await
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&mut self,
|
||||
query: &str,
|
||||
params: PostgresQueryParameters,
|
||||
) -> crate::Result<u64> {
|
||||
self.parse("", query, ¶ms);
|
||||
self.bind("", "", ¶ms);
|
||||
self.execute("", 1);
|
||||
self.sync().await?;
|
||||
|
||||
let mut affected = 0;
|
||||
|
||||
while let Some(step) = self.step().await? {
|
||||
if let Step::Command(cnt) = step {
|
||||
affected = cnt;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(affected)
|
||||
}
|
||||
|
||||
fn fetch(
|
||||
&mut self,
|
||||
query: &str,
|
||||
params: PostgresQueryParameters,
|
||||
) -> BoxStream<'_, crate::Result<PostgresRow>> {
|
||||
self.parse("", query, ¶ms);
|
||||
self.bind("", "", ¶ms);
|
||||
self.execute("", 0);
|
||||
|
||||
Box::pin(async_stream::try_stream! {
|
||||
self.sync().await?;
|
||||
|
||||
while let Some(step) = self.step().await? {
|
||||
if let Step::Row(row) = step {
|
||||
yield row;
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn fetch_optional(
|
||||
&mut self,
|
||||
query: &str,
|
||||
params: PostgresQueryParameters,
|
||||
) -> crate::Result<Option<PostgresRow>> {
|
||||
self.parse("", query, ¶ms);
|
||||
self.bind("", "", ¶ms);
|
||||
self.execute("", 2);
|
||||
self.sync().await?;
|
||||
|
||||
let mut row: Option<PostgresRow> = None;
|
||||
|
||||
while let Some(step) = self.step().await? {
|
||||
if let Step::Row(r) = step {
|
||||
if row.is_some() {
|
||||
return Err(crate::Error::FoundMoreThanOne);
|
||||
}
|
||||
|
||||
row = Some(r);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(row)
|
||||
}
|
||||
|
||||
async fn describe(&mut self, body: &str) -> crate::Result<Describe<Postgres>> {
|
||||
self.parse("", body, &PostgresQueryParameters::new());
|
||||
self.describe("");
|
||||
self.sync().await?;
|
||||
|
||||
let param_desc = loop {
|
||||
let step = self
|
||||
.step()
|
||||
.await?
|
||||
.ok_or(invalid_data!("did not receive ParameterDescription"));
|
||||
|
||||
if let Step::ParamDesc(desc) = step? {
|
||||
break desc;
|
||||
}
|
||||
};
|
||||
|
||||
let row_desc = loop {
|
||||
let step = self
|
||||
.step()
|
||||
.await?
|
||||
.ok_or(invalid_data!("did not receive RowDescription"));
|
||||
|
||||
if let Step::RowDesc(desc) = step? {
|
||||
break desc;
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Describe {
|
||||
param_types: param_desc.ids.into_vec(),
|
||||
result_fields: row_desc
|
||||
.fields
|
||||
.into_vec()
|
||||
.into_iter()
|
||||
.map(|field| ResultField {
|
||||
name: Some(field.name),
|
||||
table_id: Some(field.table_id),
|
||||
type_id: field.type_id,
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl_from_sql_row_tuples_for_backend!(Postgres);
|
||||
|
||||
@@ -1,226 +1,307 @@
|
||||
use super::{Postgres, PostgresQueryParameters, PostgresRawConnection, PostgresRow};
|
||||
use crate::{
|
||||
connection::RawConnection,
|
||||
describe::{Describe, ResultField},
|
||||
postgres::raw::Step,
|
||||
query::QueryParameters,
|
||||
url::Url,
|
||||
io::{Buf, BufStream},
|
||||
postgres::{
|
||||
protocol::{self, Decode, Encode, Message},
|
||||
PostgresDatabaseError, PostgresQueryParameters, PostgresRow,
|
||||
},
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use futures_core::stream::BoxStream;
|
||||
use async_std::net::TcpStream;
|
||||
use byteorder::NetworkEndian;
|
||||
use std::net::Shutdown;
|
||||
use std::{io, net::SocketAddr};
|
||||
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
pub struct Postgres {
|
||||
stream: BufStream<TcpStream>,
|
||||
|
||||
use crate::postgres::{protocol::Message, PostgresDatabaseError};
|
||||
use std::hash::Hasher;
|
||||
// Process ID of the Backend
|
||||
process_id: u32,
|
||||
|
||||
#[async_trait]
|
||||
impl RawConnection for PostgresRawConnection {
|
||||
type Backend = Postgres;
|
||||
|
||||
async fn establish(url: &str) -> crate::Result<Self> {
|
||||
let url = Url::parse(url)?;
|
||||
let address = url.resolve(5432);
|
||||
let mut conn = Self::new(address).await?;
|
||||
|
||||
conn.startup(
|
||||
url.username(),
|
||||
url.password().unwrap_or_default(),
|
||||
url.database(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
async fn close(mut self) -> crate::Result<()> {
|
||||
self.terminate().await
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&mut self,
|
||||
query: &str,
|
||||
params: PostgresQueryParameters,
|
||||
) -> crate::Result<u64> {
|
||||
self.parse("", query, ¶ms);
|
||||
self.bind("", "", ¶ms);
|
||||
self.execute("", 1);
|
||||
self.sync().await?;
|
||||
|
||||
let mut affected = 0;
|
||||
|
||||
while let Some(step) = self.step().await? {
|
||||
if let Step::Command(cnt) = step {
|
||||
affected = cnt;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(affected)
|
||||
}
|
||||
|
||||
fn fetch(
|
||||
&mut self,
|
||||
query: &str,
|
||||
params: PostgresQueryParameters,
|
||||
) -> BoxStream<'_, crate::Result<PostgresRow>> {
|
||||
self.parse("", query, ¶ms);
|
||||
self.bind("", "", ¶ms);
|
||||
self.execute("", 0);
|
||||
|
||||
Box::pin(async_stream::try_stream! {
|
||||
self.sync().await?;
|
||||
|
||||
while let Some(step) = self.step().await? {
|
||||
if let Step::Row(row) = step {
|
||||
yield row;
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn fetch_optional(
|
||||
&mut self,
|
||||
query: &str,
|
||||
params: PostgresQueryParameters,
|
||||
) -> crate::Result<Option<PostgresRow>> {
|
||||
self.parse("", query, ¶ms);
|
||||
self.bind("", "", ¶ms);
|
||||
self.execute("", 2);
|
||||
self.sync().await?;
|
||||
|
||||
let mut row: Option<PostgresRow> = None;
|
||||
|
||||
while let Some(step) = self.step().await? {
|
||||
if let Step::Row(r) = step {
|
||||
if row.is_some() {
|
||||
return Err(crate::Error::FoundMoreThanOne);
|
||||
}
|
||||
|
||||
row = Some(r);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(row)
|
||||
}
|
||||
|
||||
async fn describe(&mut self, body: &str) -> crate::Result<Describe<Postgres>> {
|
||||
self.parse("", body, &PostgresQueryParameters::new());
|
||||
self.describe("");
|
||||
self.sync().await?;
|
||||
|
||||
let param_desc = loop {
|
||||
let step = self
|
||||
.step()
|
||||
.await?
|
||||
.ok_or(invalid_data!("did not receive ParameterDescription"));
|
||||
|
||||
if let Step::ParamDesc(desc) = step? {
|
||||
break desc;
|
||||
}
|
||||
};
|
||||
|
||||
let row_desc = loop {
|
||||
let step = self
|
||||
.step()
|
||||
.await?
|
||||
.ok_or(invalid_data!("did not receive RowDescription"));
|
||||
|
||||
if let Step::RowDesc(desc) = step? {
|
||||
break desc;
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Describe {
|
||||
param_types: param_desc.ids.into_vec(),
|
||||
result_fields: row_desc
|
||||
.fields
|
||||
.into_vec()
|
||||
.into_iter()
|
||||
.map(|field| ResultField {
|
||||
name: Some(field.name),
|
||||
table_id: Some(field.table_id),
|
||||
type_id: field.type_id,
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
// Backend-unique key to use to send a cancel query message to the server
|
||||
secret_key: u32,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::query::QueryParameters;
|
||||
use std::env;
|
||||
// [x] 52.2.1. Start-up
|
||||
// [ ] 52.2.2. Simple Query
|
||||
// [ ] 52.2.3. Extended Query
|
||||
// [ ] 52.2.4. Function Call
|
||||
// [ ] 52.2.5. COPY Operations
|
||||
// [ ] 52.2.6. Asynchronous Operations
|
||||
// [ ] 52.2.7. Canceling Requests in Progress
|
||||
// [x] 52.2.8. Termination
|
||||
// [ ] 52.2.9. SSL Session Encryption
|
||||
// [ ] 52.2.10. GSSAPI Session Encryption
|
||||
|
||||
fn database_url() -> String {
|
||||
env::var("POSTGRES_DATABASE_URL")
|
||||
.or_else(|_| env::var("DATABASE_URL"))
|
||||
.unwrap()
|
||||
impl Postgres {
|
||||
pub(super) async fn new(address: SocketAddr) -> crate::Result<Self> {
|
||||
let stream = TcpStream::connect(&address).await?;
|
||||
|
||||
Ok(Self {
|
||||
stream: BufStream::new(stream),
|
||||
process_id: 0,
|
||||
secret_key: 0,
|
||||
})
|
||||
}
|
||||
|
||||
#[async_std::test]
|
||||
#[ignore]
|
||||
async fn it_establishes() -> crate::Result<()> {
|
||||
let mut conn = PostgresRawConnection::establish(&database_url()).await?;
|
||||
// https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.3
|
||||
pub(super) async fn startup(
|
||||
&mut self,
|
||||
username: &str,
|
||||
password: &str,
|
||||
database: &str,
|
||||
) -> crate::Result<()> {
|
||||
// See this doc for more runtime parameters
|
||||
// https://www.postgresql.org/docs/12/runtime-config-client.html
|
||||
let params = &[
|
||||
("user", username),
|
||||
("database", database),
|
||||
// Sets the display format for date and time values,
|
||||
// as well as the rules for interpreting ambiguous date input values.
|
||||
("DateStyle", "ISO, MDY"),
|
||||
// Sets the display format for interval values.
|
||||
("IntervalStyle", "iso_8601"),
|
||||
// Sets the time zone for displaying and interpreting time stamps.
|
||||
("TimeZone", "UTC"),
|
||||
// Adjust postgres to return percise values for floats
|
||||
// NOTE: This is default in postgres 12+
|
||||
("extra_float_digits", "3"),
|
||||
// Sets the client-side encoding (character set).
|
||||
("client_encoding", "UTF-8"),
|
||||
];
|
||||
|
||||
// After establish, run PING to ensure that it was established correctly
|
||||
conn.ping().await?;
|
||||
protocol::StartupMessage { params }.encode(self.stream.buffer_mut());
|
||||
self.stream.flush().await?;
|
||||
|
||||
// Then explicitly close the connection
|
||||
conn.close().await?;
|
||||
while let Some(message) = self.receive().await? {
|
||||
match message {
|
||||
Message::Authentication(auth) => {
|
||||
match *auth {
|
||||
protocol::Authentication::Ok => {
|
||||
// Do nothing. No password is needed to continue.
|
||||
}
|
||||
|
||||
protocol::Authentication::CleartextPassword => {
|
||||
protocol::PasswordMessage::Cleartext(password)
|
||||
.encode(self.stream.buffer_mut());
|
||||
|
||||
self.stream.flush().await?;
|
||||
}
|
||||
|
||||
protocol::Authentication::Md5Password { salt } => {
|
||||
protocol::PasswordMessage::Md5 {
|
||||
password,
|
||||
user: username,
|
||||
salt,
|
||||
}
|
||||
.encode(self.stream.buffer_mut());
|
||||
|
||||
self.stream.flush().await?;
|
||||
}
|
||||
|
||||
auth => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("requires unimplemented authentication method: {:?}", auth),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Message::BackendKeyData(body) => {
|
||||
self.process_id = body.process_id();
|
||||
self.secret_key = body.secret_key();
|
||||
}
|
||||
|
||||
Message::ReadyForQuery(_) => {
|
||||
// Connection fully established and ready to receive queries.
|
||||
break;
|
||||
}
|
||||
|
||||
message => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("received unexpected message: {:?}", message),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[async_std::test]
|
||||
#[ignore]
|
||||
async fn it_executes() -> crate::Result<()> {
|
||||
let mut conn = PostgresRawConnection::establish(&database_url()).await?;
|
||||
// https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.10
|
||||
pub(super) async fn terminate(mut self) -> crate::Result<()> {
|
||||
protocol::Terminate.encode(self.stream.buffer_mut());
|
||||
|
||||
let affected_rows_from_begin =
|
||||
RawConnection::execute(&mut conn, "BEGIN", PostgresQueryParameters::new()).await?;
|
||||
|
||||
assert_eq!(affected_rows_from_begin, 0);
|
||||
|
||||
let affected_rows_from_create_table = RawConnection::execute(
|
||||
&mut conn,
|
||||
r#"
|
||||
CREATE TEMP TABLE sqlx_test_it_executes (
|
||||
id BIGSERIAL PRIMARY KEY
|
||||
)
|
||||
"#,
|
||||
PostgresQueryParameters::new(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert_eq!(affected_rows_from_create_table, 0);
|
||||
|
||||
for _ in 0..5_i32 {
|
||||
let affected_rows_from_insert = RawConnection::execute(
|
||||
&mut conn,
|
||||
"INSERT INTO sqlx_test_it_executes DEFAULT VALUES",
|
||||
PostgresQueryParameters::new(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert_eq!(affected_rows_from_insert, 1);
|
||||
}
|
||||
|
||||
let affected_rows_from_delete = RawConnection::execute(
|
||||
&mut conn,
|
||||
"DELETE FROM sqlx_test_it_executes",
|
||||
PostgresQueryParameters::new(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert_eq!(affected_rows_from_delete, 5);
|
||||
|
||||
let affected_rows_from_rollback =
|
||||
RawConnection::execute(&mut conn, "ROLLBACK", PostgresQueryParameters::new()).await?;
|
||||
|
||||
assert_eq!(affected_rows_from_rollback, 0);
|
||||
self.stream.flush().await?;
|
||||
self.stream.stream.shutdown(Shutdown::Both)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) fn parse(&mut self, statement: &str, query: &str, params: &PostgresQueryParameters) {
|
||||
protocol::Parse {
|
||||
statement,
|
||||
query,
|
||||
param_types: &*params.types,
|
||||
}
|
||||
.encode(self.stream.buffer_mut());
|
||||
}
|
||||
|
||||
pub(super) fn describe(&mut self, statement: &str) {
|
||||
protocol::Describe {
|
||||
kind: protocol::DescribeKind::PreparedStatement,
|
||||
name: statement,
|
||||
}
|
||||
.encode(self.stream.buffer_mut())
|
||||
}
|
||||
|
||||
pub(super) fn bind(&mut self, portal: &str, statement: &str, params: &PostgresQueryParameters) {
|
||||
protocol::Bind {
|
||||
portal,
|
||||
statement,
|
||||
formats: &[1], // [BINARY]
|
||||
// TODO: Early error if there is more than i16
|
||||
values_len: params.types.len() as i16,
|
||||
values: &*params.buf,
|
||||
result_formats: &[1], // [BINARY]
|
||||
}
|
||||
.encode(self.stream.buffer_mut());
|
||||
}
|
||||
|
||||
pub(super) fn execute(&mut self, portal: &str, limit: i32) {
|
||||
protocol::Execute { portal, limit }.encode(self.stream.buffer_mut());
|
||||
}
|
||||
|
||||
pub(super) async fn sync(&mut self) -> crate::Result<()> {
|
||||
protocol::Sync.encode(self.stream.buffer_mut());
|
||||
|
||||
self.stream.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) async fn step(&mut self) -> crate::Result<Option<Step>> {
|
||||
while let Some(message) = self.receive().await? {
|
||||
match message {
|
||||
Message::BindComplete
|
||||
| Message::ParseComplete
|
||||
| Message::PortalSuspended
|
||||
| Message::CloseComplete => {}
|
||||
|
||||
Message::CommandComplete(body) => {
|
||||
return Ok(Some(Step::Command(body.affected_rows())));
|
||||
}
|
||||
|
||||
Message::DataRow(body) => {
|
||||
return Ok(Some(Step::Row(PostgresRow(body))));
|
||||
}
|
||||
|
||||
Message::ReadyForQuery(_) => {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Message::ParameterDescription(desc) => {
|
||||
return Ok(Some(Step::ParamDesc(desc)));
|
||||
}
|
||||
|
||||
Message::RowDescription(desc) => {
|
||||
return Ok(Some(Step::RowDesc(desc)));
|
||||
}
|
||||
|
||||
message => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("received unexpected message: {:?}", message),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Connection was (unexpectedly) closed
|
||||
Err(io::Error::from(io::ErrorKind::UnexpectedEof).into())
|
||||
}
|
||||
|
||||
// Wait and return the next message to be received from Postgres.
|
||||
pub(super) async fn receive(&mut self) -> crate::Result<Option<Message>> {
|
||||
// Before we start the receive loop
|
||||
// Flush any pending data from the send buffer
|
||||
self.stream.flush().await?;
|
||||
|
||||
loop {
|
||||
// Read the message header (id + len)
|
||||
let mut header = ret_if_none!(self.stream.peek(5).await?);
|
||||
|
||||
let id = header.get_u8()?;
|
||||
let len = (header.get_u32::<NetworkEndian>()? - 4) as usize;
|
||||
|
||||
// Read the message body
|
||||
self.stream.consume(5);
|
||||
let body = ret_if_none!(self.stream.peek(len).await?);
|
||||
|
||||
let message = match id {
|
||||
b'N' | b'E' => Message::Response(Box::new(protocol::Response::decode(body)?)),
|
||||
b'D' => Message::DataRow(protocol::DataRow::decode(body)?),
|
||||
b'S' => {
|
||||
Message::ParameterStatus(Box::new(protocol::ParameterStatus::decode(body)?))
|
||||
}
|
||||
b'Z' => Message::ReadyForQuery(protocol::ReadyForQuery::decode(body)?),
|
||||
b'R' => Message::Authentication(Box::new(protocol::Authentication::decode(body)?)),
|
||||
b'K' => Message::BackendKeyData(protocol::BackendKeyData::decode(body)?),
|
||||
b'C' => Message::CommandComplete(protocol::CommandComplete::decode(body)?),
|
||||
b'A' => Message::NotificationResponse(Box::new(
|
||||
protocol::NotificationResponse::decode(body)?,
|
||||
)),
|
||||
b'1' => Message::ParseComplete,
|
||||
b'2' => Message::BindComplete,
|
||||
b'3' => Message::CloseComplete,
|
||||
b'n' => Message::NoData,
|
||||
b's' => Message::PortalSuspended,
|
||||
b't' => Message::ParameterDescription(Box::new(
|
||||
protocol::ParameterDescription::decode(body)?,
|
||||
)),
|
||||
b'T' => Message::RowDescription(Box::new(protocol::RowDescription::decode(body)?)),
|
||||
|
||||
id => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("received unknown message id: {:?}", id),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
};
|
||||
|
||||
self.stream.consume(len);
|
||||
|
||||
match message {
|
||||
Message::ParameterStatus(_body) => {
|
||||
// TODO: not sure what to do with these yet
|
||||
}
|
||||
|
||||
Message::Response(body) => {
|
||||
if body.severity().is_error() {
|
||||
// This is an error, stop the world and bubble as an error
|
||||
return Err(PostgresDatabaseError(body).into());
|
||||
} else {
|
||||
// This is a _warning_
|
||||
// TODO: Do we *want* to do anything with these
|
||||
}
|
||||
}
|
||||
|
||||
message => {
|
||||
return Ok(Some(message));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(super) enum Step {
|
||||
Command(u64),
|
||||
Row(PostgresRow),
|
||||
ParamDesc(Box<protocol::ParameterDescription>),
|
||||
RowDesc(Box<protocol::RowDescription>),
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ mod backend;
|
||||
mod connection;
|
||||
mod error;
|
||||
mod query;
|
||||
mod raw;
|
||||
mod row;
|
||||
|
||||
#[cfg(not(feature = "unstable"))]
|
||||
@@ -14,6 +13,6 @@ pub mod protocol;
|
||||
pub mod types;
|
||||
|
||||
pub use self::{
|
||||
backend::Postgres, error::PostgresDatabaseError, query::PostgresQueryParameters,
|
||||
raw::PostgresRawConnection, row::PostgresRow,
|
||||
connection::Postgres, error::PostgresDatabaseError, query::PostgresQueryParameters,
|
||||
row::PostgresRow,
|
||||
};
|
||||
|
||||
@@ -1,307 +0,0 @@
|
||||
use crate::{
|
||||
io::{Buf, BufStream},
|
||||
postgres::{
|
||||
protocol::{self, Decode, Encode, Message},
|
||||
PostgresDatabaseError, PostgresQueryParameters, PostgresRow,
|
||||
},
|
||||
};
|
||||
use std::net::Shutdown;
|
||||
use byteorder::NetworkEndian;
|
||||
use std::{io, net::SocketAddr};
|
||||
use async_std::net::TcpStream;
|
||||
|
||||
pub struct PostgresRawConnection {
|
||||
stream: BufStream<TcpStream>,
|
||||
|
||||
// Process ID of the Backend
|
||||
process_id: u32,
|
||||
|
||||
// Backend-unique key to use to send a cancel query message to the server
|
||||
secret_key: u32,
|
||||
}
|
||||
|
||||
// [x] 52.2.1. Start-up
|
||||
// [ ] 52.2.2. Simple Query
|
||||
// [ ] 52.2.3. Extended Query
|
||||
// [ ] 52.2.4. Function Call
|
||||
// [ ] 52.2.5. COPY Operations
|
||||
// [ ] 52.2.6. Asynchronous Operations
|
||||
// [ ] 52.2.7. Canceling Requests in Progress
|
||||
// [x] 52.2.8. Termination
|
||||
// [ ] 52.2.9. SSL Session Encryption
|
||||
// [ ] 52.2.10. GSSAPI Session Encryption
|
||||
|
||||
impl PostgresRawConnection {
|
||||
pub(super) async fn new(address: SocketAddr) -> crate::Result<Self> {
|
||||
let stream = TcpStream::connect(&address).await?;
|
||||
|
||||
Ok(Self {
|
||||
stream: BufStream::new(stream),
|
||||
process_id: 0,
|
||||
secret_key: 0,
|
||||
})
|
||||
}
|
||||
|
||||
// https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.3
|
||||
pub(super) async fn startup(
|
||||
&mut self,
|
||||
username: &str,
|
||||
password: &str,
|
||||
database: &str,
|
||||
) -> crate::Result<()> {
|
||||
// See this doc for more runtime parameters
|
||||
// https://www.postgresql.org/docs/12/runtime-config-client.html
|
||||
let params = &[
|
||||
("user", username),
|
||||
("database", database),
|
||||
// Sets the display format for date and time values,
|
||||
// as well as the rules for interpreting ambiguous date input values.
|
||||
("DateStyle", "ISO, MDY"),
|
||||
// Sets the display format for interval values.
|
||||
("IntervalStyle", "iso_8601"),
|
||||
// Sets the time zone for displaying and interpreting time stamps.
|
||||
("TimeZone", "UTC"),
|
||||
// Adjust postgres to return percise values for floats
|
||||
// NOTE: This is default in postgres 12+
|
||||
("extra_float_digits", "3"),
|
||||
// Sets the client-side encoding (character set).
|
||||
("client_encoding", "UTF-8"),
|
||||
];
|
||||
|
||||
protocol::StartupMessage { params }.encode(self.stream.buffer_mut());
|
||||
self.stream.flush().await?;
|
||||
|
||||
while let Some(message) = self.receive().await? {
|
||||
match message {
|
||||
Message::Authentication(auth) => {
|
||||
match *auth {
|
||||
protocol::Authentication::Ok => {
|
||||
// Do nothing. No password is needed to continue.
|
||||
}
|
||||
|
||||
protocol::Authentication::CleartextPassword => {
|
||||
protocol::PasswordMessage::Cleartext(password)
|
||||
.encode(self.stream.buffer_mut());
|
||||
|
||||
self.stream.flush().await?;
|
||||
}
|
||||
|
||||
protocol::Authentication::Md5Password { salt } => {
|
||||
protocol::PasswordMessage::Md5 {
|
||||
password,
|
||||
user: username,
|
||||
salt,
|
||||
}
|
||||
.encode(self.stream.buffer_mut());
|
||||
|
||||
self.stream.flush().await?;
|
||||
}
|
||||
|
||||
auth => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("requires unimplemented authentication method: {:?}", auth),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Message::BackendKeyData(body) => {
|
||||
self.process_id = body.process_id();
|
||||
self.secret_key = body.secret_key();
|
||||
}
|
||||
|
||||
Message::ReadyForQuery(_) => {
|
||||
// Connection fully established and ready to receive queries.
|
||||
break;
|
||||
}
|
||||
|
||||
message => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("received unexpected message: {:?}", message),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.10
|
||||
pub(super) async fn terminate(mut self) -> crate::Result<()> {
|
||||
protocol::Terminate.encode(self.stream.buffer_mut());
|
||||
|
||||
self.stream.flush().await?;
|
||||
self.stream.stream.shutdown(Shutdown::Both)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) fn parse(&mut self, statement: &str, query: &str, params: &PostgresQueryParameters) {
|
||||
protocol::Parse {
|
||||
statement,
|
||||
query,
|
||||
param_types: &*params.types,
|
||||
}
|
||||
.encode(self.stream.buffer_mut());
|
||||
}
|
||||
|
||||
pub(super) fn describe(&mut self, statement: &str) {
|
||||
protocol::Describe {
|
||||
kind: protocol::DescribeKind::PreparedStatement,
|
||||
name: statement,
|
||||
}
|
||||
.encode(self.stream.buffer_mut())
|
||||
}
|
||||
|
||||
pub(super) fn bind(&mut self, portal: &str, statement: &str, params: &PostgresQueryParameters) {
|
||||
protocol::Bind {
|
||||
portal,
|
||||
statement,
|
||||
formats: &[1], // [BINARY]
|
||||
// TODO: Early error if there is more than i16
|
||||
values_len: params.types.len() as i16,
|
||||
values: &*params.buf,
|
||||
result_formats: &[1], // [BINARY]
|
||||
}
|
||||
.encode(self.stream.buffer_mut());
|
||||
}
|
||||
|
||||
pub(super) fn execute(&mut self, portal: &str, limit: i32) {
|
||||
protocol::Execute { portal, limit }.encode(self.stream.buffer_mut());
|
||||
}
|
||||
|
||||
pub(super) async fn sync(&mut self) -> crate::Result<()> {
|
||||
protocol::Sync.encode(self.stream.buffer_mut());
|
||||
|
||||
self.stream.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) async fn step(&mut self) -> crate::Result<Option<Step>> {
|
||||
while let Some(message) = self.receive().await? {
|
||||
match message {
|
||||
Message::BindComplete
|
||||
| Message::ParseComplete
|
||||
| Message::PortalSuspended
|
||||
| Message::CloseComplete => {}
|
||||
|
||||
Message::CommandComplete(body) => {
|
||||
return Ok(Some(Step::Command(body.affected_rows())));
|
||||
}
|
||||
|
||||
Message::DataRow(body) => {
|
||||
return Ok(Some(Step::Row(PostgresRow(body))));
|
||||
}
|
||||
|
||||
Message::ReadyForQuery(_) => {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Message::ParameterDescription(desc) => {
|
||||
return Ok(Some(Step::ParamDesc(desc)));
|
||||
}
|
||||
|
||||
Message::RowDescription(desc) => {
|
||||
return Ok(Some(Step::RowDesc(desc)));
|
||||
}
|
||||
|
||||
message => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("received unexpected message: {:?}", message),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Connection was (unexpectedly) closed
|
||||
Err(io::Error::from(io::ErrorKind::UnexpectedEof).into())
|
||||
}
|
||||
|
||||
// Wait and return the next message to be received from Postgres.
|
||||
pub(super) async fn receive(&mut self) -> crate::Result<Option<Message>> {
|
||||
// Before we start the receive loop
|
||||
// Flush any pending data from the send buffer
|
||||
self.stream.flush().await?;
|
||||
|
||||
loop {
|
||||
// Read the message header (id + len)
|
||||
let mut header = ret_if_none!(self.stream.peek(5).await?);
|
||||
|
||||
let id = header.get_u8()?;
|
||||
let len = (header.get_u32::<NetworkEndian>()? - 4) as usize;
|
||||
|
||||
// Read the message body
|
||||
self.stream.consume(5);
|
||||
let body = ret_if_none!(self.stream.peek(len).await?);
|
||||
|
||||
let message = match id {
|
||||
b'N' | b'E' => Message::Response(Box::new(protocol::Response::decode(body)?)),
|
||||
b'D' => Message::DataRow(protocol::DataRow::decode(body)?),
|
||||
b'S' => {
|
||||
Message::ParameterStatus(Box::new(protocol::ParameterStatus::decode(body)?))
|
||||
}
|
||||
b'Z' => Message::ReadyForQuery(protocol::ReadyForQuery::decode(body)?),
|
||||
b'R' => Message::Authentication(Box::new(protocol::Authentication::decode(body)?)),
|
||||
b'K' => Message::BackendKeyData(protocol::BackendKeyData::decode(body)?),
|
||||
b'C' => Message::CommandComplete(protocol::CommandComplete::decode(body)?),
|
||||
b'A' => Message::NotificationResponse(Box::new(
|
||||
protocol::NotificationResponse::decode(body)?,
|
||||
)),
|
||||
b'1' => Message::ParseComplete,
|
||||
b'2' => Message::BindComplete,
|
||||
b'3' => Message::CloseComplete,
|
||||
b'n' => Message::NoData,
|
||||
b's' => Message::PortalSuspended,
|
||||
b't' => Message::ParameterDescription(Box::new(
|
||||
protocol::ParameterDescription::decode(body)?,
|
||||
)),
|
||||
b'T' => Message::RowDescription(Box::new(protocol::RowDescription::decode(body)?)),
|
||||
|
||||
id => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("received unknown message id: {:?}", id),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
};
|
||||
|
||||
self.stream.consume(len);
|
||||
|
||||
match message {
|
||||
Message::ParameterStatus(_body) => {
|
||||
// TODO: not sure what to do with these yet
|
||||
}
|
||||
|
||||
Message::Response(body) => {
|
||||
if body.severity().is_error() {
|
||||
// This is an error, stop the world and bubble as an error
|
||||
return Err(PostgresDatabaseError(body).into());
|
||||
} else {
|
||||
// This is a _warning_
|
||||
// TODO: Do we *want* to do anything with these
|
||||
}
|
||||
}
|
||||
|
||||
message => {
|
||||
return Ok(Some(message));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(super) enum Step {
|
||||
Command(u64),
|
||||
Row(PostgresRow),
|
||||
ParamDesc(Box<protocol::ParameterDescription>),
|
||||
RowDesc(Box<protocol::RowDescription>),
|
||||
}
|
||||
|
||||
@@ -88,22 +88,22 @@ async fn process_sql(input: MacroInput) -> Result<TokenStream> {
|
||||
|
||||
match db_url.scheme() {
|
||||
#[cfg(feature = "postgres")]
|
||||
"postgresql" => {
|
||||
"postgresql" | "postgres" => {
|
||||
process_sql_with(
|
||||
input,
|
||||
sqlx::Connection::<sqlx::Postgres>::establish(db_url.as_str())
|
||||
sqlx::Connection::<sqlx::Postgres>::open(db_url.as_str())
|
||||
.await
|
||||
.map_err(|e| format!("failed to connect to database: {}", e))?,
|
||||
)
|
||||
.await
|
||||
}
|
||||
#[cfg(feature = "mariadb")]
|
||||
"mysql" => {
|
||||
"mysql" | "mariadb" => {
|
||||
process_sql_with(
|
||||
input,
|
||||
sqlx::Connection::<sqlx::MariaDb>::establish(db_url.as_str())
|
||||
.await
|
||||
.map_err(|e| format!("failed to connect to database: {}", e))?,
|
||||
sqlx::Connection::<sqlx::MariaDb>::open(db_url.as_str())
|
||||
.await
|
||||
.map_err(|e| format!("failed to connect to database: {}", e))?,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#[async_std::test]
|
||||
async fn test_sqlx_macro() -> sqlx::Result<()> {
|
||||
let mut conn =
|
||||
sqlx::Connection::<sqlx::Postgres>::establish("postgres://postgres@127.0.0.1/sqlx_test")
|
||||
sqlx::Connection::<sqlx::Postgres>::open("postgres://postgres@127.0.0.1/sqlx_test")
|
||||
.await?;
|
||||
let uuid: sqlx::types::Uuid = "256ba9c8-0048-11ea-b0f0-8f04859d047e".parse().unwrap();
|
||||
let accounts = sqlx::query!("SELECT * from accounts where id != $1", None)
|
||||
|
||||
Reference in New Issue
Block a user