mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-12-29 21:00:54 +00:00
Extract BufStream from PostgresRawConnection
This commit is contained in:
parent
6c98ba01b8
commit
c33530b25c
@ -82,14 +82,14 @@ async fn insert(pool: &PostgresPool, count: usize) -> Result<(), sqlx::Error> {
|
||||
r#"
|
||||
INSERT INTO contacts (name, username, password, email, phone)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
"#,
|
||||
"#,
|
||||
(
|
||||
contact.name,
|
||||
contact.username,
|
||||
contact.password,
|
||||
contact.email,
|
||||
contact.phone,
|
||||
)
|
||||
),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@ -115,14 +115,16 @@ async fn select(pool: &PostgresPool, iterations: usize) -> Result<(), sqlx::Erro
|
||||
|
||||
for _ in 0..iterations {
|
||||
// TODO: Once we have FromRow derives we can replace this with Vec<Contact>
|
||||
let contacts: Vec<(String, String, String, String, String)> = pool.fetch(
|
||||
r#"
|
||||
let contacts: Vec<(String, String, String, String, String)> = pool
|
||||
.fetch(
|
||||
r#"
|
||||
SELECT name, username, password, email, phone
|
||||
FROM contacts
|
||||
"#, (),
|
||||
)
|
||||
.try_collect()
|
||||
.await?;
|
||||
"#,
|
||||
(),
|
||||
)
|
||||
.try_collect()
|
||||
.await?;
|
||||
|
||||
rows = contacts.len();
|
||||
}
|
||||
|
||||
@ -76,6 +76,7 @@ SELECT id, text
|
||||
FROM tasks
|
||||
WHERE done_at IS NULL
|
||||
"#,
|
||||
(),
|
||||
)
|
||||
.try_for_each(|(id, text): (i64, String)| {
|
||||
// language=text
|
||||
@ -89,7 +90,8 @@ WHERE done_at IS NULL
|
||||
}
|
||||
|
||||
async fn add_task(conn: &mut Connection<Postgres>, text: &str) -> Fallible<()> {
|
||||
conn.execute("INSERT INTO tasks ( text ) VALUES ( $1 )", (text,)).await?;
|
||||
conn.execute("INSERT INTO tasks ( text ) VALUES ( $1 )", (text,))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1,5 +1,9 @@
|
||||
use crate::{
|
||||
backend::Backend, error::Error, executor::Executor, query::{QueryParameters, IntoQueryParameters}, row::FromSqlRow,
|
||||
backend::Backend,
|
||||
error::Error,
|
||||
executor::Executor,
|
||||
query::{IntoQueryParameters, QueryParameters},
|
||||
row::FromSqlRow,
|
||||
};
|
||||
use crossbeam_queue::SegQueue;
|
||||
use crossbeam_utils::atomic::AtomicCell;
|
||||
|
||||
@ -40,6 +40,13 @@ impl Display for Error {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<io::Error> for Error {
|
||||
#[inline]
|
||||
fn from(err: io::Error) -> Self {
|
||||
Error::Io(err)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Define a RawError type for the database backend for forwarding error information
|
||||
|
||||
/// An error that was returned by the database backend.
|
||||
|
||||
102
src/io/buf_stream.rs
Normal file
102
src/io/buf_stream.rs
Normal file
@ -0,0 +1,102 @@
|
||||
use bytes::{BufMut, BytesMut};
|
||||
use std::io;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
pub struct BufStream<S> {
|
||||
stream: S,
|
||||
|
||||
// Have we reached end-of-file (been disconnected)
|
||||
stream_eof: bool,
|
||||
|
||||
// Buffer used when sending outgoing messages
|
||||
wbuf: Vec<u8>,
|
||||
|
||||
// Buffer used when reading incoming messages
|
||||
rbuf: BytesMut,
|
||||
}
|
||||
|
||||
impl<S> BufStream<S>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
pub fn new(stream: S) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
stream_eof: false,
|
||||
wbuf: Vec::with_capacity(1 * 1024),
|
||||
rbuf: BytesMut::with_capacity(8 * 1024),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn close(&mut self) -> io::Result<()> {
|
||||
self.stream.shutdown().await
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn buffer_mut(&mut self) -> &mut Vec<u8> {
|
||||
&mut self.wbuf
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub async fn flush(&mut self) -> io::Result<()> {
|
||||
if self.wbuf.len() > 0 {
|
||||
self.stream.write_all(&self.wbuf).await?;
|
||||
self.wbuf.clear();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn consume(&mut self, cnt: usize) {
|
||||
self.rbuf.advance(cnt);
|
||||
}
|
||||
|
||||
pub async fn peek(&mut self, cnt: usize) -> io::Result<Option<&[u8]>> {
|
||||
loop {
|
||||
// Reaching end-of-file (read 0 bytes) will continuously
|
||||
// return None from all future calls to read
|
||||
if self.stream_eof {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// If we have enough bytes in our read buffer,
|
||||
// return immediately
|
||||
if self.rbuf.len() >= cnt {
|
||||
return Ok(Some(&self.rbuf[..cnt]));
|
||||
}
|
||||
|
||||
if self.rbuf.capacity() < cnt {
|
||||
// Ask for exactly how much we need with a lower bound of 32-bytes
|
||||
let needed = (cnt - self.rbuf.capacity()).max(32);
|
||||
self.rbuf.reserve(needed);
|
||||
}
|
||||
|
||||
// SAFE: Read data in directly to buffer without zero-initializing the data.
|
||||
// Postgres is a self-describing format and the TCP frames encode
|
||||
// length headers. We will never attempt to decode more than we
|
||||
// received.
|
||||
let n = self.stream.read(unsafe { self.rbuf.bytes_mut() }).await?;
|
||||
|
||||
// SAFE: After we read in N bytes, we can tell the buffer that it actually
|
||||
// has that many bytes MORE for the decode routines to look at
|
||||
unsafe { self.rbuf.advance_mut(n) }
|
||||
|
||||
if n == 0 {
|
||||
self.stream_eof = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return `Ok(None)` immediately from a function if the wrapped value is `None`
|
||||
macro_rules! ret_if_none {
|
||||
($val:expr) => {
|
||||
match $val {
|
||||
Some(val) => val,
|
||||
None => {
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
4
src/io/mod.rs
Normal file
4
src/io/mod.rs
Normal file
@ -0,0 +1,4 @@
|
||||
#[macro_use]
|
||||
mod buf_stream;
|
||||
|
||||
pub use self::buf_stream::BufStream;
|
||||
@ -11,6 +11,9 @@
|
||||
#[macro_use]
|
||||
mod macros;
|
||||
|
||||
#[macro_use]
|
||||
mod io;
|
||||
|
||||
pub mod backend;
|
||||
pub mod deserialize;
|
||||
|
||||
|
||||
@ -1,6 +1,10 @@
|
||||
use crate::{
|
||||
backend::Backend, connection::RawConnection, error::Error, executor::Executor,
|
||||
query::{QueryParameters, IntoQueryParameters}, row::FromSqlRow,
|
||||
backend::Backend,
|
||||
connection::RawConnection,
|
||||
error::Error,
|
||||
executor::Executor,
|
||||
query::{IntoQueryParameters, QueryParameters},
|
||||
row::FromSqlRow,
|
||||
};
|
||||
use crossbeam_queue::{ArrayQueue, SegQueue};
|
||||
use futures_channel::oneshot;
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
use super::{
|
||||
protocol::{self, Encode, Message, Terminate},
|
||||
protocol::{self, Decode, Encode, Message, Terminate},
|
||||
Postgres, PostgresQueryParameters, PostgresRow,
|
||||
};
|
||||
use crate::{connection::RawConnection, error::Error, query::QueryParameters};
|
||||
use bytes::{BufMut, BytesMut};
|
||||
use crate::{connection::RawConnection, error::Error, io::BufStream, query::QueryParameters};
|
||||
// use bytes::{BufMut, BytesMut};
|
||||
use super::protocol::Buf;
|
||||
use futures_core::{future::BoxFuture, stream::BoxStream};
|
||||
use std::{
|
||||
io,
|
||||
@ -21,20 +22,7 @@ mod fetch;
|
||||
mod fetch_optional;
|
||||
|
||||
pub struct PostgresRawConnection {
|
||||
stream: TcpStream,
|
||||
|
||||
// Do we think that there is data in the read buffer to be decoded
|
||||
stream_readable: bool,
|
||||
|
||||
// Have we reached end-of-file (been disconnected)
|
||||
stream_eof: bool,
|
||||
|
||||
// Buffer used when sending outgoing messages
|
||||
pub(super) wbuf: Vec<u8>,
|
||||
|
||||
// Buffer used when reading incoming messages
|
||||
// TODO: Evaluate if we _really_ want to use BytesMut here
|
||||
rbuf: BytesMut,
|
||||
stream: BufStream<TcpStream>,
|
||||
|
||||
// Process ID of the Backend
|
||||
process_id: u32,
|
||||
@ -58,11 +46,7 @@ impl PostgresRawConnection {
|
||||
let stream = TcpStream::connect(&addr).await.map_err(Error::Io)?;
|
||||
|
||||
let mut conn = Self {
|
||||
wbuf: Vec::with_capacity(1024),
|
||||
rbuf: BytesMut::with_capacity(1024 * 8),
|
||||
stream,
|
||||
stream_readable: false,
|
||||
stream_eof: false,
|
||||
stream: BufStream::new(stream),
|
||||
process_id: 0,
|
||||
secret_key: 0,
|
||||
};
|
||||
@ -74,8 +58,8 @@ impl PostgresRawConnection {
|
||||
|
||||
async fn finalize(&mut self) -> Result<(), Error> {
|
||||
self.write(Terminate);
|
||||
self.flush().await?;
|
||||
self.stream.shutdown(Shutdown::Both).map_err(Error::Io)?;
|
||||
self.stream.flush().await?;
|
||||
self.stream.close().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -83,69 +67,64 @@ impl PostgresRawConnection {
|
||||
// Wait and return the next message to be received from Postgres.
|
||||
async fn receive(&mut self) -> Result<Option<Message>, Error> {
|
||||
loop {
|
||||
if self.stream_eof {
|
||||
// Reached end-of-file on a previous read call.
|
||||
return Ok(None);
|
||||
}
|
||||
// Read the message header (id + len)
|
||||
let mut header = ret_if_none!(self.stream.peek(5).await?);
|
||||
let id = header.get_int_1()?;
|
||||
let len = (header.get_int_4()? - 4) as usize;
|
||||
|
||||
if self.stream_readable {
|
||||
loop {
|
||||
match Message::decode(&mut self.rbuf) {
|
||||
Some(Message::ParameterStatus(_body)) => {
|
||||
// TODO: not sure what to do with these yet
|
||||
}
|
||||
// Read the message body
|
||||
self.stream.consume(5);
|
||||
let body = ret_if_none!(self.stream.peek(len).await?);
|
||||
|
||||
Some(Message::Response(_body)) => {
|
||||
// TODO: Transform Errors+ into an error type and return
|
||||
// TODO: Log all others
|
||||
}
|
||||
let message = match id {
|
||||
b'N' | b'E' => Message::Response(Box::new(protocol::Response::decode(body))),
|
||||
b'D' => Message::DataRow(Box::new(protocol::DataRow::decode(body))),
|
||||
b'S' => Message::ParameterStatus(Box::new(protocol::ParameterStatus::decode(body))),
|
||||
b'Z' => Message::ReadyForQuery(protocol::ReadyForQuery::decode(body)),
|
||||
b'R' => Message::Authentication(Box::new(protocol::Authentication::decode(body))),
|
||||
b'K' => Message::BackendKeyData(protocol::BackendKeyData::decode(body)),
|
||||
b'T' => Message::RowDescription(Box::new(protocol::RowDescription::decode(body))),
|
||||
b'C' => Message::CommandComplete(protocol::CommandComplete::decode(body)),
|
||||
b'A' => Message::NotificationResponse(Box::new(
|
||||
protocol::NotificationResponse::decode(body),
|
||||
)),
|
||||
b'1' => Message::ParseComplete,
|
||||
b'2' => Message::BindComplete,
|
||||
b'3' => Message::CloseComplete,
|
||||
b'n' => Message::NoData,
|
||||
b's' => Message::PortalSuspended,
|
||||
b't' => Message::ParameterDescription(Box::new(
|
||||
protocol::ParameterDescription::decode(body),
|
||||
)),
|
||||
|
||||
Some(message) => {
|
||||
return Ok(Some(message));
|
||||
}
|
||||
_ => unimplemented!("unknown message id: {}", id as char),
|
||||
};
|
||||
|
||||
None => {
|
||||
// Not enough data in the read buffer to parse a message
|
||||
self.stream_readable = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
self.stream.consume(len);
|
||||
|
||||
match message {
|
||||
Message::ParameterStatus(_body) => {
|
||||
// TODO: not sure what to do with these yet
|
||||
}
|
||||
|
||||
Message::Response(_body) => {
|
||||
// TODO: Transform Errors+ into an error type and return
|
||||
// TODO: Log all others
|
||||
}
|
||||
|
||||
message => {
|
||||
return Ok(Some(message));
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure there is at least 32-bytes of space available
|
||||
// in the read buffer so we can safely detect end-of-file
|
||||
self.rbuf.reserve(32);
|
||||
|
||||
// SAFE: Read data in directly to buffer without zero-initializing the data.
|
||||
// Postgres is a self-describing format and the TCP frames encode
|
||||
// length headers. We will never attempt to decode more than we
|
||||
// received.
|
||||
let n = self
|
||||
.stream
|
||||
.read(unsafe { self.rbuf.bytes_mut() })
|
||||
.await
|
||||
.map_err(Error::Io)?;
|
||||
|
||||
// SAFE: After we read in N bytes, we can tell the buffer that it actually
|
||||
// has that many bytes MORE for the decode routines to look at
|
||||
unsafe { self.rbuf.advance_mut(n) }
|
||||
|
||||
if n == 0 {
|
||||
self.stream_eof = true;
|
||||
}
|
||||
|
||||
self.stream_readable = true;
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn write(&mut self, message: impl Encode) {
|
||||
message.encode(&mut self.wbuf);
|
||||
message.encode(self.stream.buffer_mut());
|
||||
}
|
||||
|
||||
async fn flush(&mut self) -> Result<(), Error> {
|
||||
self.stream.write_all(&self.wbuf).await.map_err(Error::Io)?;
|
||||
self.wbuf.clear();
|
||||
self.stream.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -195,7 +174,12 @@ impl RawConnection for PostgresRawConnection {
|
||||
}
|
||||
}
|
||||
|
||||
fn finish(conn: &mut PostgresRawConnection, query: &str, params: PostgresQueryParameters, limit: i32) {
|
||||
fn finish(
|
||||
conn: &mut PostgresRawConnection,
|
||||
query: &str,
|
||||
params: PostgresQueryParameters,
|
||||
limit: i32,
|
||||
) {
|
||||
conn.write(protocol::Parse {
|
||||
portal: "",
|
||||
query,
|
||||
@ -213,10 +197,7 @@ fn finish(conn: &mut PostgresRawConnection, query: &str, params: PostgresQueryPa
|
||||
});
|
||||
|
||||
// TODO: Make limit be 1 for fetch_optional
|
||||
conn.write(protocol::Execute {
|
||||
portal: "",
|
||||
limit,
|
||||
});
|
||||
conn.write(protocol::Execute { portal: "", limit });
|
||||
|
||||
conn.write(protocol::Sync);
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
use memchr::memchr;
|
||||
use std::str;
|
||||
use std::{convert::TryInto, io, str};
|
||||
|
||||
pub trait Decode {
|
||||
fn decode(src: &[u8]) -> Self
|
||||
@ -14,3 +14,50 @@ pub(crate) fn get_str(src: &[u8]) -> &str {
|
||||
|
||||
unsafe { str::from_utf8_unchecked(buf) }
|
||||
}
|
||||
|
||||
pub trait Buf {
|
||||
fn advance(&mut self, cnt: usize);
|
||||
|
||||
// An n-bit integer in network byte order
|
||||
fn get_int_1(&mut self) -> io::Result<u8>;
|
||||
fn get_int_4(&mut self) -> io::Result<u32>;
|
||||
|
||||
// A null-terminated string
|
||||
fn get_str(&mut self) -> io::Result<&str>;
|
||||
}
|
||||
|
||||
impl<'a> Buf for &'a [u8] {
|
||||
#[inline]
|
||||
fn advance(&mut self, cnt: usize) {
|
||||
*self = &self[cnt..];
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn get_int_1(&mut self) -> io::Result<u8> {
|
||||
let val = self[0];
|
||||
|
||||
self.advance(1);
|
||||
|
||||
Ok(val)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn get_int_4(&mut self) -> io::Result<u32> {
|
||||
let val: [u8; 4] = (*self)
|
||||
.try_into()
|
||||
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
|
||||
|
||||
self.advance(4);
|
||||
|
||||
Ok(u32::from_be_bytes(val))
|
||||
}
|
||||
|
||||
fn get_str(&mut self) -> io::Result<&str> {
|
||||
let end = memchr(b'\0', &*self).ok_or(io::ErrorKind::InvalidData)?;
|
||||
let buf = &self[..end];
|
||||
|
||||
self.advance(end);
|
||||
|
||||
str::from_utf8(buf).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))
|
||||
}
|
||||
}
|
||||
|
||||
@ -25,61 +25,3 @@ pub enum Message {
|
||||
PortalSuspended,
|
||||
ParameterDescription(Box<ParameterDescription>),
|
||||
}
|
||||
|
||||
impl Message {
|
||||
// FIXME: `Message::decode` shares the name of the remaining message type `::decode` despite being very
|
||||
// different
|
||||
pub fn decode(src: &mut BytesMut) -> Option<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
if src.len() < 5 {
|
||||
// No message is less than 5 bytes
|
||||
return None;
|
||||
}
|
||||
|
||||
let token = src[0];
|
||||
if token == 0 {
|
||||
// FIXME: Handle end-of-stream
|
||||
panic!("unexpectede end-of-stream");
|
||||
}
|
||||
|
||||
// FIXME: What happens if len(u32) < len(usize) ?
|
||||
let len = BigEndian::read_u32(&src[1..5]) as usize;
|
||||
|
||||
if src.len() >= (len + 1) {
|
||||
let window = &src[5..=len];
|
||||
|
||||
let message = match token {
|
||||
b'N' | b'E' => Message::Response(Box::new(Response::decode(window))),
|
||||
b'D' => Message::DataRow(Box::new(DataRow::decode(window))),
|
||||
b'S' => Message::ParameterStatus(Box::new(ParameterStatus::decode(window))),
|
||||
b'Z' => Message::ReadyForQuery(ReadyForQuery::decode(window)),
|
||||
b'R' => Message::Authentication(Box::new(Authentication::decode(window))),
|
||||
b'K' => Message::BackendKeyData(BackendKeyData::decode(window)),
|
||||
b'T' => Message::RowDescription(Box::new(RowDescription::decode(window))),
|
||||
b'C' => Message::CommandComplete(CommandComplete::decode(window)),
|
||||
b'A' => {
|
||||
Message::NotificationResponse(Box::new(NotificationResponse::decode(window)))
|
||||
}
|
||||
b'1' => Message::ParseComplete,
|
||||
b'2' => Message::BindComplete,
|
||||
b'3' => Message::CloseComplete,
|
||||
b'n' => Message::NoData,
|
||||
b's' => Message::PortalSuspended,
|
||||
b't' => {
|
||||
Message::ParameterDescription(Box::new(ParameterDescription::decode(window)))
|
||||
}
|
||||
|
||||
_ => unimplemented!("decode not implemented for token: {}", token as char),
|
||||
};
|
||||
|
||||
src.advance(len + 1);
|
||||
|
||||
Some(message)
|
||||
} else {
|
||||
// We don't have enough in the stream yet
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -59,9 +59,16 @@ mod response;
|
||||
mod row_description;
|
||||
|
||||
pub use self::{
|
||||
authentication::Authentication, backend_key_data::BackendKeyData,
|
||||
command_complete::CommandComplete, data_row::DataRow, decode::Decode, message::Message,
|
||||
notification_response::NotificationResponse, parameter_description::ParameterDescription,
|
||||
parameter_status::ParameterStatus, ready_for_query::ReadyForQuery, response::Response,
|
||||
authentication::Authentication,
|
||||
backend_key_data::BackendKeyData,
|
||||
command_complete::CommandComplete,
|
||||
data_row::DataRow,
|
||||
decode::{Buf, Decode},
|
||||
message::Message,
|
||||
notification_response::NotificationResponse,
|
||||
parameter_description::ParameterDescription,
|
||||
parameter_status::ParameterStatus,
|
||||
ready_for_query::ReadyForQuery,
|
||||
response::Response,
|
||||
row_description::RowDescription,
|
||||
};
|
||||
|
||||
17
src/query.rs
17
src/query.rs
@ -18,7 +18,10 @@ pub trait QueryParameters: Send {
|
||||
T: ToSql<Self::Backend>;
|
||||
}
|
||||
|
||||
pub trait IntoQueryParameters<DB> where DB: Backend {
|
||||
pub trait IntoQueryParameters<DB>
|
||||
where
|
||||
DB: Backend,
|
||||
{
|
||||
fn into(self) -> DB::QueryParameters;
|
||||
}
|
||||
|
||||
@ -26,9 +29,9 @@ pub trait IntoQueryParameters<DB> where DB: Backend {
|
||||
macro_rules! impl_into_query_parameters {
|
||||
|
||||
($( ($idx:tt) -> $T:ident );+;) => {
|
||||
impl<$($T,)+ DB> IntoQueryParameters<DB> for ($($T,)+)
|
||||
where
|
||||
DB: Backend,
|
||||
impl<$($T,)+ DB> IntoQueryParameters<DB> for ($($T,)+)
|
||||
where
|
||||
DB: Backend,
|
||||
$(DB: crate::types::HasSqlType<$T>,)+
|
||||
$($T: crate::serialize::ToSql<DB>,)+
|
||||
{
|
||||
@ -41,9 +44,9 @@ macro_rules! impl_into_query_parameters {
|
||||
};
|
||||
}
|
||||
|
||||
impl<DB> IntoQueryParameters<DB> for ()
|
||||
where
|
||||
DB: Backend,
|
||||
impl<DB> IntoQueryParameters<DB> for ()
|
||||
where
|
||||
DB: Backend,
|
||||
{
|
||||
fn into(self) -> DB::QueryParameters {
|
||||
DB::QueryParameters::new()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user