mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-12-29 21:00:54 +00:00
Clean up Postgres Connection and start (finally) bubbling database errors
This commit is contained in:
parent
1b7e7c2729
commit
d26f72e1e6
@ -33,6 +33,7 @@ url = "2.1.0"
|
||||
|
||||
[dev-dependencies]
|
||||
matches = "0.1.8"
|
||||
tokio = { version = "=0.2.0-alpha.2", default-features = false, features = [ "rt-full" ] }
|
||||
|
||||
[profile.release]
|
||||
lto = true
|
||||
|
||||
32
src/error.rs
32
src/error.rs
@ -1,9 +1,11 @@
|
||||
use std::{
|
||||
error::Error as StdError,
|
||||
fmt::{self, Display},
|
||||
fmt::{self, Debug, Display},
|
||||
io,
|
||||
};
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
/// Error communicating with the database backend.
|
||||
@ -22,7 +24,7 @@ pub enum Error {
|
||||
Io(io::Error),
|
||||
|
||||
/// An error was returned by the database backend.
|
||||
Database(DbError),
|
||||
Database(Box<dyn DbError + Send + Sync>),
|
||||
|
||||
/// No rows were returned by a query expected to return at least one row.
|
||||
NotFound,
|
||||
@ -49,21 +51,19 @@ impl From<io::Error> for Error {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Define a RawError type for the database backend for forwarding error information
|
||||
impl<T> From<T> for Error
|
||||
where
|
||||
T: 'static + DbError,
|
||||
{
|
||||
#[inline]
|
||||
fn from(err: T) -> Self {
|
||||
Error::Database(Box::new(err))
|
||||
}
|
||||
}
|
||||
|
||||
/// An error that was returned by the database backend.
|
||||
#[derive(Debug)]
|
||||
pub struct DbError {
|
||||
message: String,
|
||||
}
|
||||
pub trait DbError: Debug + Send + Sync {
|
||||
fn message(&self) -> &str;
|
||||
|
||||
impl DbError {
|
||||
pub(crate) fn new(message: String) -> Self {
|
||||
Self { message }
|
||||
}
|
||||
|
||||
/// The primary human-readable error message.
|
||||
pub fn message(&self) -> &str {
|
||||
&self.message
|
||||
}
|
||||
// TODO: Expose more error properties
|
||||
}
|
||||
|
||||
@ -32,7 +32,7 @@ pub mod types;
|
||||
|
||||
pub use self::{
|
||||
connection::Connection,
|
||||
error::Error,
|
||||
error::{Error, Result},
|
||||
executor::Executor,
|
||||
pool::Pool,
|
||||
sql::{query, SqlQuery},
|
||||
|
||||
372
src/postgres/connection.rs
Normal file
372
src/postgres/connection.rs
Normal file
@ -0,0 +1,372 @@
|
||||
use super::{
|
||||
protocol::{self, Decode, Encode, Message, Terminate},
|
||||
Postgres, PostgresError, PostgresQueryParameters, PostgresRow,
|
||||
};
|
||||
use crate::{
|
||||
connection::RawConnection,
|
||||
error::Error,
|
||||
io::{Buf, BufStream},
|
||||
query::QueryParameters,
|
||||
url::Url,
|
||||
};
|
||||
use byteorder::NetworkEndian;
|
||||
use futures_core::{future::BoxFuture, stream::BoxStream};
|
||||
use std::{
|
||||
io,
|
||||
net::{IpAddr, Shutdown, SocketAddr},
|
||||
sync::atomic::{AtomicU64, Ordering},
|
||||
};
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
net::TcpStream,
|
||||
};
|
||||
|
||||
pub struct PostgresRawConnection {
|
||||
stream: BufStream<TcpStream>,
|
||||
|
||||
// Process ID of the Backend
|
||||
process_id: u32,
|
||||
|
||||
// Backend-unique key to use to send a cancel query message to the server
|
||||
secret_key: u32,
|
||||
|
||||
// Statement ID counter
|
||||
next_statement_id: AtomicU64,
|
||||
|
||||
// Portal ID counter
|
||||
next_portal_id: AtomicU64,
|
||||
}
|
||||
|
||||
impl PostgresRawConnection {
|
||||
async fn establish(url: &str) -> crate::Result<Self> {
|
||||
let url = Url::parse(url)?;
|
||||
let stream = TcpStream::connect(&url.resolve(5432)).await?;
|
||||
|
||||
let mut conn = Self {
|
||||
stream: BufStream::new(stream),
|
||||
process_id: 0,
|
||||
secret_key: 0,
|
||||
next_statement_id: AtomicU64::new(0),
|
||||
next_portal_id: AtomicU64::new(0),
|
||||
};
|
||||
|
||||
let user = url.username();
|
||||
let password = url.password().unwrap_or("");
|
||||
let database = url.database();
|
||||
|
||||
// See this doc for more runtime parameters
|
||||
// https://www.postgresql.org/docs/12/runtime-config-client.html
|
||||
let params = &[
|
||||
// FIXME: ConnectOptions user and database need to be required parameters and error
|
||||
// before they get here
|
||||
("user", user),
|
||||
("database", database),
|
||||
// Sets the display format for date and time values,
|
||||
// as well as the rules for interpreting ambiguous date input values.
|
||||
("DateStyle", "ISO, MDY"),
|
||||
// Sets the display format for interval values.
|
||||
("IntervalStyle", "iso_8601"),
|
||||
// Sets the time zone for displaying and interpreting time stamps.
|
||||
("TimeZone", "UTC"),
|
||||
// Adjust postgres to return percise values for floats
|
||||
// NOTE: This is default in postgres 12+
|
||||
("extra_float_digits", "3"),
|
||||
// Sets the client-side encoding (character set).
|
||||
("client_encoding", "UTF-8"),
|
||||
];
|
||||
|
||||
conn.write(protocol::StartupMessage { params });
|
||||
conn.stream.flush().await?;
|
||||
|
||||
while let Some(message) = conn.receive().await? {
|
||||
match message {
|
||||
Message::Authentication(auth) => {
|
||||
match *auth {
|
||||
protocol::Authentication::Ok => {
|
||||
// Do nothing. No password is needed to continue.
|
||||
}
|
||||
|
||||
protocol::Authentication::CleartextPassword => {
|
||||
conn.write(protocol::PasswordMessage::Cleartext(password));
|
||||
|
||||
conn.stream.flush().await?;
|
||||
}
|
||||
|
||||
protocol::Authentication::Md5Password { salt } => {
|
||||
conn.write(protocol::PasswordMessage::Md5 {
|
||||
password,
|
||||
user,
|
||||
salt,
|
||||
});
|
||||
|
||||
conn.stream.flush().await?;
|
||||
}
|
||||
|
||||
auth => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("requires unimplemented authentication method: {:?}", auth),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Message::BackendKeyData(body) => {
|
||||
conn.process_id = body.process_id();
|
||||
conn.secret_key = body.secret_key();
|
||||
}
|
||||
|
||||
Message::ReadyForQuery(_) => {
|
||||
break;
|
||||
}
|
||||
|
||||
message => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("received unexpected message: {:?}", message),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
async fn finalize(&mut self) -> crate::Result<()> {
|
||||
self.write(Terminate);
|
||||
self.stream.flush().await?;
|
||||
self.stream.close().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Wait and return the next message to be received from Postgres.
|
||||
async fn receive(&mut self) -> crate::Result<Option<Message>> {
|
||||
// Before we start the receive loop
|
||||
// Flush any pending data from the send buffer
|
||||
self.stream.flush().await?;
|
||||
|
||||
loop {
|
||||
// Read the message header (id + len)
|
||||
let mut header = ret_if_none!(self.stream.peek(5).await?);
|
||||
|
||||
let id = header.get_u8()?;
|
||||
let len = (header.get_u32::<NetworkEndian>()? - 4) as usize;
|
||||
|
||||
// Read the message body
|
||||
self.stream.consume(5);
|
||||
let body = ret_if_none!(self.stream.peek(len).await?);
|
||||
|
||||
let message = match id {
|
||||
b'N' | b'E' => Message::Response(Box::new(protocol::Response::decode(body)?)),
|
||||
b'D' => Message::DataRow(protocol::DataRow::decode(body)?),
|
||||
b'S' => {
|
||||
Message::ParameterStatus(Box::new(protocol::ParameterStatus::decode(body)?))
|
||||
}
|
||||
b'Z' => Message::ReadyForQuery(protocol::ReadyForQuery::decode(body)?),
|
||||
b'R' => Message::Authentication(Box::new(protocol::Authentication::decode(body)?)),
|
||||
b'K' => Message::BackendKeyData(protocol::BackendKeyData::decode(body)?),
|
||||
b'C' => Message::CommandComplete(protocol::CommandComplete::decode(body)?),
|
||||
b'A' => Message::NotificationResponse(Box::new(
|
||||
protocol::NotificationResponse::decode(body)?,
|
||||
)),
|
||||
b'1' => Message::ParseComplete,
|
||||
b'2' => Message::BindComplete,
|
||||
b'3' => Message::CloseComplete,
|
||||
b'n' => Message::NoData,
|
||||
b's' => Message::PortalSuspended,
|
||||
b't' => Message::ParameterDescription(Box::new(
|
||||
protocol::ParameterDescription::decode(body)?,
|
||||
)),
|
||||
|
||||
id => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("received unexpected message id: {:?}", id),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
};
|
||||
|
||||
self.stream.consume(len);
|
||||
|
||||
match message {
|
||||
Message::ParameterStatus(_body) => {
|
||||
// TODO: not sure what to do with these yet
|
||||
}
|
||||
|
||||
Message::Response(body) => {
|
||||
if body.severity().is_error() {
|
||||
// This is an error, stop the world and bubble as an error
|
||||
return Err(PostgresError(body).into());
|
||||
} else {
|
||||
// This is a _warning_
|
||||
// TODO: Do we *want* to do anything with these
|
||||
}
|
||||
}
|
||||
|
||||
message => {
|
||||
return Ok(Some(message));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn write(&mut self, message: impl Encode) {
|
||||
message.encode(self.stream.buffer_mut());
|
||||
}
|
||||
|
||||
fn execute(&mut self, query: &str, params: PostgresQueryParameters, limit: i32) {
|
||||
self.write(protocol::Parse {
|
||||
portal: "",
|
||||
query,
|
||||
param_types: &*params.types,
|
||||
});
|
||||
|
||||
self.write(protocol::Bind {
|
||||
portal: "",
|
||||
statement: "",
|
||||
formats: &[1], // [BINARY]
|
||||
// TODO: Early error if there is more than i16
|
||||
values_len: params.types.len() as i16,
|
||||
values: &*params.buf,
|
||||
result_formats: &[1], // [BINARY]
|
||||
});
|
||||
|
||||
// TODO: Make limit be 1 for fetch_optional
|
||||
self.write(protocol::Execute { portal: "", limit });
|
||||
|
||||
self.write(protocol::Sync);
|
||||
}
|
||||
|
||||
// Ask for the next Row in the stream
|
||||
async fn step(&mut self) -> crate::Result<Option<Step>> {
|
||||
while let Some(message) = self.receive().await? {
|
||||
match message {
|
||||
Message::BindComplete
|
||||
| Message::ParseComplete
|
||||
| Message::PortalSuspended
|
||||
| Message::CloseComplete => {}
|
||||
|
||||
Message::CommandComplete(body) => {
|
||||
return Ok(Some(Step::Command(body.affected_rows())));
|
||||
}
|
||||
|
||||
Message::DataRow(body) => {
|
||||
return Ok(Some(Step::Row(PostgresRow(body))));
|
||||
}
|
||||
|
||||
Message::ReadyForQuery(_) => {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
message => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("received unexpected message: {:?}", message),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Connection was (unexpectedly) closed
|
||||
Err(io::Error::from(io::ErrorKind::UnexpectedEof).into())
|
||||
}
|
||||
|
||||
// TODO: Remove usage of fmt!
|
||||
fn next_portal_id(&self) -> String {
|
||||
format!(
|
||||
"__sqlx_portal_{}",
|
||||
self.next_portal_id.fetch_add(1, Ordering::AcqRel)
|
||||
)
|
||||
}
|
||||
|
||||
// TODO: Remove usage of fmt!
|
||||
fn next_statement_id(&self) -> String {
|
||||
format!(
|
||||
"__sqlx_statement_{}",
|
||||
self.next_statement_id.fetch_add(1, Ordering::AcqRel)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
enum Step {
|
||||
Command(u64),
|
||||
Row(PostgresRow),
|
||||
}
|
||||
|
||||
impl RawConnection for PostgresRawConnection {
|
||||
type Backend = Postgres;
|
||||
|
||||
#[inline]
|
||||
fn establish(url: &str) -> BoxFuture<crate::Result<Self>> {
|
||||
Box::pin(Self::establish(url))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn finalize<'c>(&'c mut self) -> BoxFuture<'c, crate::Result<()>> {
|
||||
Box::pin(self.finalize())
|
||||
}
|
||||
|
||||
fn execute<'c>(
|
||||
&'c mut self,
|
||||
query: &str,
|
||||
params: PostgresQueryParameters,
|
||||
) -> BoxFuture<'c, crate::Result<u64>> {
|
||||
self.execute(query, params, 1);
|
||||
|
||||
Box::pin(async move {
|
||||
let mut affected = 0;
|
||||
|
||||
while let Some(step) = self.step().await? {
|
||||
if let Step::Command(cnt) = step {
|
||||
affected = cnt;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(affected)
|
||||
})
|
||||
}
|
||||
|
||||
fn fetch<'c>(
|
||||
&'c mut self,
|
||||
query: &str,
|
||||
params: PostgresQueryParameters,
|
||||
) -> BoxStream<'c, crate::Result<PostgresRow>> {
|
||||
self.execute(query, params, 0);
|
||||
|
||||
Box::pin(async_stream::try_stream! {
|
||||
while let Some(step) = self.step().await? {
|
||||
if let Step::Row(row) = step {
|
||||
yield row;
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn fetch_optional<'c>(
|
||||
&'c mut self,
|
||||
query: &str,
|
||||
params: PostgresQueryParameters,
|
||||
) -> BoxFuture<'c, crate::Result<Option<PostgresRow>>> {
|
||||
self.execute(query, params, 1);
|
||||
|
||||
Box::pin(async move {
|
||||
let mut row: Option<PostgresRow> = None;
|
||||
|
||||
while let Some(step) = self.step().await? {
|
||||
if let Step::Row(r) = step {
|
||||
// This should only ever execute once because we used the
|
||||
// protocol-level limit
|
||||
debug_assert!(row.is_none());
|
||||
row = Some(r);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(row)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -1,91 +0,0 @@
|
||||
use super::PostgresRawConnection;
|
||||
use crate::{
|
||||
error::Error,
|
||||
postgres::protocol::{Authentication, Message, PasswordMessage, StartupMessage},
|
||||
url::Url,
|
||||
};
|
||||
use std::io;
|
||||
|
||||
pub async fn establish<'a, 'b: 'a>(
|
||||
conn: &'a mut PostgresRawConnection,
|
||||
url: &'b Url,
|
||||
) -> Result<(), Error> {
|
||||
let user = url.username();
|
||||
let password = url.password().unwrap_or("");
|
||||
let database = url.database();
|
||||
|
||||
// See this doc for more runtime parameters
|
||||
// https://www.postgresql.org/docs/12/runtime-config-client.html
|
||||
let params = &[
|
||||
// FIXME: ConnectOptions user and database need to be required parameters and error
|
||||
// before they get here
|
||||
("user", user),
|
||||
("database", database),
|
||||
// Sets the display format for date and time values,
|
||||
// as well as the rules for interpreting ambiguous date input values.
|
||||
("DateStyle", "ISO, MDY"),
|
||||
// Sets the display format for interval values.
|
||||
("IntervalStyle", "iso_8601"),
|
||||
// Sets the time zone for displaying and interpreting time stamps.
|
||||
("TimeZone", "UTC"),
|
||||
// Adjust postgres to return percise values for floats
|
||||
// NOTE: This is default in postgres 12+
|
||||
("extra_float_digits", "3"),
|
||||
// Sets the client-side encoding (character set).
|
||||
("client_encoding", "UTF-8"),
|
||||
];
|
||||
|
||||
let message = StartupMessage { params };
|
||||
|
||||
conn.write(message);
|
||||
conn.stream.flush().await?;
|
||||
|
||||
while let Some(message) = conn.receive().await? {
|
||||
match message {
|
||||
Message::Authentication(auth) => {
|
||||
match *auth {
|
||||
Authentication::Ok => {
|
||||
// Do nothing. No password is needed to continue.
|
||||
}
|
||||
|
||||
Authentication::CleartextPassword => {
|
||||
// FIXME: Should error early (before send) if the user did not supply a password
|
||||
conn.write(PasswordMessage::Cleartext(password));
|
||||
|
||||
conn.stream.flush().await?;
|
||||
}
|
||||
|
||||
Authentication::Md5Password { salt } => {
|
||||
// FIXME: Should error early (before send) if the user did not supply a password
|
||||
conn.write(PasswordMessage::Md5 {
|
||||
password,
|
||||
user,
|
||||
salt,
|
||||
});
|
||||
|
||||
conn.stream.flush().await?;
|
||||
}
|
||||
|
||||
auth => {
|
||||
unimplemented!("received {:?} unimplemented authentication message", auth);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Message::BackendKeyData(body) => {
|
||||
conn.process_id = body.process_id();
|
||||
conn.secret_key = body.secret_key();
|
||||
}
|
||||
|
||||
Message::ReadyForQuery(_) => {
|
||||
break;
|
||||
}
|
||||
|
||||
message => {
|
||||
unimplemented!("received {:?} unimplemented message", message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -1,31 +0,0 @@
|
||||
use super::PostgresRawConnection;
|
||||
use crate::{error::Error, postgres::protocol::Message};
|
||||
use std::io;
|
||||
|
||||
pub async fn execute(conn: &mut PostgresRawConnection) -> Result<u64, Error> {
|
||||
conn.stream.flush().await?;
|
||||
|
||||
let mut rows = 0;
|
||||
|
||||
while let Some(message) = conn.receive().await? {
|
||||
match message {
|
||||
Message::BindComplete | Message::ParseComplete | Message::DataRow(_) => {}
|
||||
|
||||
Message::CommandComplete(body) => {
|
||||
rows = body.affected_rows();
|
||||
}
|
||||
|
||||
Message::ReadyForQuery(_) => {
|
||||
// Successful completion of the whole cycle
|
||||
return Ok(rows);
|
||||
}
|
||||
|
||||
message => {
|
||||
unimplemented!("received {:?} unimplemented message", message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME: This is an end-of-file error. How we should bubble this up here?
|
||||
unreachable!()
|
||||
}
|
||||
@ -1,37 +0,0 @@
|
||||
use super::{PostgresRawConnection, PostgresRow};
|
||||
use crate::{error::Error, postgres::protocol::Message};
|
||||
use futures_core::stream::Stream;
|
||||
use std::io;
|
||||
|
||||
pub fn fetch<'a>(
|
||||
conn: &'a mut PostgresRawConnection,
|
||||
) -> impl Stream<Item = Result<PostgresRow, Error>> + 'a {
|
||||
async_stream::try_stream! {
|
||||
conn.stream.flush().await.map_err(Error::from)?;
|
||||
|
||||
while let Some(message) = conn.receive().await? {
|
||||
match message {
|
||||
Message::BindComplete
|
||||
| Message::ParseComplete
|
||||
| Message::PortalSuspended
|
||||
| Message::CloseComplete
|
||||
| Message::CommandComplete(_) => {}
|
||||
|
||||
Message::DataRow(body) => {
|
||||
yield PostgresRow(body);
|
||||
}
|
||||
|
||||
Message::ReadyForQuery(_) => {
|
||||
return;
|
||||
}
|
||||
|
||||
message => {
|
||||
unimplemented!("received {:?} unimplemented message", message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME: This is an end-of-file error. How we should bubble this up here?
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
@ -1,36 +0,0 @@
|
||||
use super::{PostgresRawConnection, PostgresRow};
|
||||
use crate::{error::Error, postgres::protocol::Message};
|
||||
use std::io;
|
||||
|
||||
pub async fn fetch_optional<'a>(
|
||||
conn: &'a mut PostgresRawConnection,
|
||||
) -> Result<Option<PostgresRow>, Error> {
|
||||
conn.stream.flush().await?;
|
||||
|
||||
let mut row: Option<PostgresRow> = None;
|
||||
|
||||
while let Some(message) = conn.receive().await? {
|
||||
match message {
|
||||
Message::BindComplete
|
||||
| Message::ParseComplete
|
||||
| Message::PortalSuspended
|
||||
| Message::CloseComplete
|
||||
| Message::CommandComplete(_) => {}
|
||||
|
||||
Message::DataRow(body) => {
|
||||
row = Some(PostgresRow(body));
|
||||
}
|
||||
|
||||
Message::ReadyForQuery(_) => {
|
||||
return Ok(row);
|
||||
}
|
||||
|
||||
message => {
|
||||
unimplemented!("received {:?} unimplemented message", message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME: This is an end-of-file error. How we should bubble this up here?
|
||||
unreachable!()
|
||||
}
|
||||
@ -1,199 +0,0 @@
|
||||
use super::{
|
||||
protocol::{self, Decode, Encode, Message, Terminate},
|
||||
Postgres, PostgresQueryParameters, PostgresRow,
|
||||
};
|
||||
use crate::{connection::RawConnection, error::Error, io::BufStream, query::QueryParameters};
|
||||
// use bytes::{BufMut, BytesMut};
|
||||
use crate::{io::Buf, url::Url};
|
||||
use byteorder::NetworkEndian;
|
||||
use futures_core::{future::BoxFuture, stream::BoxStream};
|
||||
use std::{
|
||||
io,
|
||||
net::{IpAddr, Shutdown, SocketAddr},
|
||||
};
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
net::TcpStream,
|
||||
};
|
||||
|
||||
mod establish;
|
||||
mod execute;
|
||||
mod fetch;
|
||||
mod fetch_optional;
|
||||
|
||||
pub struct PostgresRawConnection {
|
||||
stream: BufStream<TcpStream>,
|
||||
|
||||
// Process ID of the Backend
|
||||
process_id: u32,
|
||||
|
||||
// Backend-unique key to use to send a cancel query message to the server
|
||||
secret_key: u32,
|
||||
}
|
||||
|
||||
impl PostgresRawConnection {
|
||||
async fn establish(url: &str) -> Result<Self, Error> {
|
||||
let url = Url::parse(url);
|
||||
let stream = TcpStream::connect(&url.address(5432)).await?;
|
||||
|
||||
let mut conn = Self {
|
||||
stream: BufStream::new(stream),
|
||||
process_id: 0,
|
||||
secret_key: 0,
|
||||
};
|
||||
|
||||
establish::establish(&mut conn, &url).await?;
|
||||
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
async fn finalize(&mut self) -> Result<(), Error> {
|
||||
self.write(Terminate);
|
||||
self.stream.flush().await?;
|
||||
self.stream.close().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Wait and return the next message to be received from Postgres.
|
||||
async fn receive(&mut self) -> Result<Option<Message>, Error> {
|
||||
loop {
|
||||
// Read the message header (id + len)
|
||||
let mut header = ret_if_none!(self.stream.peek(5).await?);
|
||||
log::trace!("recv:header {:?}", bytes::Bytes::from(&*header));
|
||||
|
||||
let id = header.get_u8()?;
|
||||
let len = (header.get_u32::<NetworkEndian>()? - 4) as usize;
|
||||
|
||||
// Read the message body
|
||||
self.stream.consume(5);
|
||||
let body = ret_if_none!(self.stream.peek(len).await?);
|
||||
log::trace!("recv {:?}", bytes::Bytes::from(&*body));
|
||||
|
||||
let message = match id {
|
||||
b'N' | b'E' => Message::Response(Box::new(protocol::Response::decode(body)?)),
|
||||
b'D' => Message::DataRow(protocol::DataRow::decode(body)?),
|
||||
b'S' => {
|
||||
Message::ParameterStatus(Box::new(protocol::ParameterStatus::decode(body)?))
|
||||
}
|
||||
b'Z' => Message::ReadyForQuery(protocol::ReadyForQuery::decode(body)?),
|
||||
b'R' => Message::Authentication(Box::new(protocol::Authentication::decode(body)?)),
|
||||
b'K' => Message::BackendKeyData(protocol::BackendKeyData::decode(body)?),
|
||||
b'C' => Message::CommandComplete(protocol::CommandComplete::decode(body)?),
|
||||
b'A' => Message::NotificationResponse(Box::new(
|
||||
protocol::NotificationResponse::decode(body)?,
|
||||
)),
|
||||
b'1' => Message::ParseComplete,
|
||||
b'2' => Message::BindComplete,
|
||||
b'3' => Message::CloseComplete,
|
||||
b'n' => Message::NoData,
|
||||
b's' => Message::PortalSuspended,
|
||||
b't' => Message::ParameterDescription(Box::new(
|
||||
protocol::ParameterDescription::decode(body)?,
|
||||
)),
|
||||
|
||||
_ => unimplemented!("unknown message id: {}", id as char),
|
||||
};
|
||||
|
||||
self.stream.consume(len);
|
||||
|
||||
match message {
|
||||
Message::ParameterStatus(_body) => {
|
||||
// TODO: not sure what to do with these yet
|
||||
}
|
||||
|
||||
Message::Response(_body) => {
|
||||
// TODO: Transform Errors+ into an error type and return
|
||||
// TODO: Log all others
|
||||
}
|
||||
|
||||
message => {
|
||||
return Ok(Some(message));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn write(&mut self, message: impl Encode) {
|
||||
let pos = self.stream.buffer_mut().len();
|
||||
|
||||
message.encode(self.stream.buffer_mut());
|
||||
|
||||
log::trace!(
|
||||
"send {:?}",
|
||||
bytes::Bytes::from(&self.stream.buffer_mut()[pos..])
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl RawConnection for PostgresRawConnection {
|
||||
type Backend = Postgres;
|
||||
|
||||
#[inline]
|
||||
fn establish(url: &str) -> BoxFuture<Result<Self, Error>> {
|
||||
Box::pin(PostgresRawConnection::establish(url))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn finalize<'c>(&'c mut self) -> BoxFuture<'c, Result<(), Error>> {
|
||||
Box::pin(self.finalize())
|
||||
}
|
||||
|
||||
fn execute<'c>(
|
||||
&'c mut self,
|
||||
query: &str,
|
||||
params: PostgresQueryParameters,
|
||||
) -> BoxFuture<'c, Result<u64, Error>> {
|
||||
finish(self, query, params, 0);
|
||||
|
||||
Box::pin(execute::execute(self))
|
||||
}
|
||||
|
||||
fn fetch<'c>(
|
||||
&'c mut self,
|
||||
query: &str,
|
||||
params: PostgresQueryParameters,
|
||||
) -> BoxStream<'c, Result<PostgresRow, Error>> {
|
||||
finish(self, query, params, 0);
|
||||
|
||||
Box::pin(fetch::fetch(self))
|
||||
}
|
||||
|
||||
fn fetch_optional<'c>(
|
||||
&'c mut self,
|
||||
query: &str,
|
||||
params: PostgresQueryParameters,
|
||||
) -> BoxFuture<'c, Result<Option<PostgresRow>, Error>> {
|
||||
finish(self, query, params, 1);
|
||||
|
||||
Box::pin(fetch_optional::fetch_optional(self))
|
||||
}
|
||||
}
|
||||
|
||||
fn finish(
|
||||
conn: &mut PostgresRawConnection,
|
||||
query: &str,
|
||||
params: PostgresQueryParameters,
|
||||
limit: i32,
|
||||
) {
|
||||
conn.write(protocol::Parse {
|
||||
portal: "",
|
||||
query,
|
||||
param_types: &*params.types,
|
||||
});
|
||||
|
||||
conn.write(protocol::Bind {
|
||||
portal: "",
|
||||
statement: "",
|
||||
formats: &[1], // [BINARY]
|
||||
// TODO: Early error if there is more than i16
|
||||
values_len: params.types.len() as i16,
|
||||
values: &*params.buf,
|
||||
result_formats: &[1], // [BINARY]
|
||||
});
|
||||
|
||||
// TODO: Make limit be 1 for fetch_optional
|
||||
conn.write(protocol::Execute { portal: "", limit });
|
||||
|
||||
conn.write(protocol::Sync);
|
||||
}
|
||||
11
src/postgres/error.rs
Normal file
11
src/postgres/error.rs
Normal file
@ -0,0 +1,11 @@
|
||||
use super::protocol::Response;
|
||||
use crate::error::DbError;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PostgresError(pub(super) Box<Response>);
|
||||
|
||||
impl DbError for PostgresError {
|
||||
fn message(&self) -> &str {
|
||||
self.0.message()
|
||||
}
|
||||
}
|
||||
@ -1,11 +1,74 @@
|
||||
mod backend;
|
||||
mod connection;
|
||||
mod error;
|
||||
mod protocol;
|
||||
mod query;
|
||||
mod row;
|
||||
pub mod types;
|
||||
|
||||
pub use self::{
|
||||
backend::Postgres, connection::PostgresRawConnection, query::PostgresQueryParameters,
|
||||
row::PostgresRow,
|
||||
backend::Postgres, connection::PostgresRawConnection, error::PostgresError,
|
||||
query::PostgresQueryParameters, row::PostgresRow,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{Postgres, PostgresRawConnection};
|
||||
use crate::connection::{Connection, RawConnection};
|
||||
use futures_util::TryStreamExt;
|
||||
|
||||
const DATABASE_URL: &str = "postgres://postgres@127.0.0.1:5432/";
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_connects() {
|
||||
let mut conn = PostgresRawConnection::establish(DATABASE_URL)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
conn.finalize().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_fails_on_connect_with_an_unknown_user() {
|
||||
let res = PostgresRawConnection::establish("postgres://not_a_user@127.0.0.1:5432/").await;
|
||||
|
||||
match res {
|
||||
Err(crate::Error::Database(err)) => {
|
||||
assert_eq!(err.message(), "role \"not_a_user\" does not exist");
|
||||
}
|
||||
|
||||
_ => panic!("unexpected result"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_fails_on_connect_with_an_unknown_database() {
|
||||
let res =
|
||||
PostgresRawConnection::establish("postgres://postgres@127.0.0.1:5432/fdggsdfgsdaf")
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Err(crate::Error::Database(err)) => {
|
||||
assert_eq!(err.message(), "database \"fdggsdfgsdaf\" does not exist");
|
||||
}
|
||||
|
||||
_ => panic!("unexpected result"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_fetches_tuples() {
|
||||
let conn = Connection::<Postgres>::establish(DATABASE_URL)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let roles: Vec<(String, bool)> = crate::query("SELECT rolname, rolsuper FROM pg_roles")
|
||||
.fetch(&conn)
|
||||
.try_collect()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Sanity check to be sure we did indeed fetch tuples
|
||||
assert!(roles.binary_search(&("postgres".to_string(), true)).is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
@ -3,9 +3,9 @@ use std::net::{IpAddr, SocketAddr};
|
||||
pub struct Url(url::Url);
|
||||
|
||||
impl Url {
|
||||
pub fn parse(url: &str) -> Self {
|
||||
pub fn parse(url: &str) -> crate::Result<Self> {
|
||||
// TODO: Handle parse errors
|
||||
Url(url::Url::parse(url).unwrap())
|
||||
Ok(Url(url::Url::parse(url).unwrap()))
|
||||
}
|
||||
|
||||
pub fn host(&self) -> &str {
|
||||
@ -16,7 +16,7 @@ impl Url {
|
||||
self.0.port().unwrap_or(default)
|
||||
}
|
||||
|
||||
pub fn address(&self, default_port: u16) -> SocketAddr {
|
||||
pub fn resolve(&self, default_port: u16) -> SocketAddr {
|
||||
// TODO: DNS
|
||||
let host: IpAddr = self.host().parse().unwrap();
|
||||
let port = self.port(default_port);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user