postgres: Stream::read -> Stream::receive and extract "just reading" to Stream::read

This commit is contained in:
Ryan Leckey
2020-03-16 18:30:45 -07:00
parent 0ecacfaf1d
commit b80080a95b
5 changed files with 56 additions and 37 deletions

View File

@@ -131,7 +131,7 @@ async fn startup(stream: &mut PgStream, url: &Url) -> crate::Result<BackendKeyDa
};
loop {
match stream.read().await? {
match stream.receive().await? {
Message::Authentication => match Authentication::read(stream.buffer())? {
Authentication::Ok => {
// do nothing. no password is needed to continue.

View File

@@ -77,7 +77,7 @@ async fn expect_desc(
conn: &mut PgConnection,
) -> crate::Result<(HashMap<Box<str>, usize>, Vec<TypeFormat>)> {
let description: Option<_> = loop {
match conn.stream.read().await? {
match conn.stream.receive().await? {
Message::ParseComplete | Message::BindComplete => {}
Message::RowDescription => {
@@ -148,7 +148,7 @@ async fn next<'a, 'c: 'a, 'q: 'a>(
}
loop {
match conn.stream.read().await? {
match conn.stream.receive().await? {
// Indicates that a phase of the extended query flow has completed
// We as SQLx don't generally care as long as it is happening
Message::ParseComplete | Message::BindComplete => {}

View File

@@ -73,7 +73,7 @@ impl PgConnection {
if !self.is_ready {
loop {
if let Message::ReadyForQuery = self.stream.read().await? {
if let Message::ReadyForQuery = self.stream.receive().await? {
// we are now ready to go
self.is_ready = true;
break;
@@ -136,7 +136,7 @@ impl PgConnection {
Ok(statement)
}
async fn describe<'e, 'q: 'e>(
async fn do_describe<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
) -> crate::Result<Describe<Postgres>> {
@@ -150,7 +150,7 @@ impl PgConnection {
self.stream.flush().await?;
let params = loop {
match self.stream.read().await? {
match self.stream.receive().await? {
Message::ParseComplete => {}
Message::ParameterDescription => {
@@ -167,7 +167,7 @@ impl PgConnection {
};
};
let result = match self.stream.read().await? {
let result = match self.stream.receive().await? {
Message::NoData => None,
Message::RowDescription => Some(RowDescription::read(self.stream.buffer())?),
@@ -329,7 +329,7 @@ impl PgConnection {
let mut rows = 0;
loop {
match self.stream.read().await? {
match self.stream.receive().await? {
Message::ParseComplete
| Message::BindComplete
| Message::NoData
@@ -397,7 +397,7 @@ impl Executor for super::PgConnection {
where
E: Execute<'q, Self::Database>,
{
Box::pin(async move { self.describe(query.into_parts().0).await })
Box::pin(async move { self.do_describe(query.into_parts().0).await })
}
}

View File

@@ -65,7 +65,7 @@ pub(super) async fn authenticate<T: AsRef<str>>(
stream.write(SaslInitialResponse(&client_first_message));
stream.flush().await?;
let server_first_message = stream.read().await?;
let server_first_message = stream.receive().await?;
if let Message::Authentication = server_first_message {
let auth = Authentication::read(stream.buffer())?;
@@ -140,7 +140,7 @@ pub(super) async fn authenticate<T: AsRef<str>>(
stream.write(SaslResponse(&client_final_message));
stream.flush().await?;
let _server_final_response = stream.read().await?;
let _server_final_response = stream.receive().await?;
// todo: assert that this was SaslFinal?
Ok(())

View File

@@ -2,14 +2,17 @@ use std::convert::TryInto;
use std::net::Shutdown;
use byteorder::NetworkEndian;
use futures_channel::mpsc::UnboundedSender;
use crate::io::{Buf, BufStream, MaybeTlsStream};
use crate::postgres::protocol::{Message, Response, Write};
use crate::postgres::protocol::{Message, NotificationResponse, Response, Write};
use crate::postgres::PgError;
use crate::url::Url;
use futures_util::SinkExt;
pub struct PgStream {
pub(super) stream: BufStream<MaybeTlsStream>,
pub(super) notifications: Option<UnboundedSender<NotificationResponse<'static>>>,
// Most recently received message
// Is referenced by our buffered stream
@@ -22,6 +25,7 @@ impl PgStream {
let stream = MaybeTlsStream::connect(&url, 5432).await?;
Ok(Self {
notifications: None,
stream: BufStream::new(stream),
message: (Message::ReadyForQuery, 0),
})
@@ -45,30 +49,36 @@ impl PgStream {
}
pub(super) async fn read(&mut self) -> crate::Result<Message> {
// https://www.postgresql.org/docs/12/protocol-overview.html#PROTOCOL-MESSAGE-CONCEPTS
// All communication is through a stream of messages. The first byte of a message
// identifies the message type, and the next four bytes specify the length of the rest of
// the message (this length count includes itself, but not the message-type byte).
if self.message.1 > 0 {
// If there is any data in our read buffer we need to make sure we flush that
// so reading will return the *next* message
self.stream.consume(self.message.1 as usize);
}
let mut header = self.stream.peek(4 + 1).await?;
let type_ = header.get_u8()?.try_into()?;
let length = header.get_u32::<NetworkEndian>()? - 4;
self.message = (type_, length);
self.stream.consume(4 + 1);
// Wait until there is enough data in the stream. We then return without actually
// inspecting the data. This is then looked at later through the [buffer] function
let _ = self.stream.peek(length as usize).await?;
Ok(type_)
}
pub(super) async fn receive(&mut self) -> crate::Result<Message> {
loop {
// https://www.postgresql.org/docs/12/protocol-overview.html#PROTOCOL-MESSAGE-CONCEPTS
// All communication is through a stream of messages. The first byte of a message
// identifies the message type, and the next four bytes specify the length of the rest of
// the message (this length count includes itself, but not the message-type byte).
if self.message.1 > 0 {
// If there is any data in our read buffer we need to make sure we flush that
// so reading will return the *next* message
self.stream.consume(self.message.1 as usize);
}
let mut header = self.stream.peek(4 + 1).await?;
let type_ = header.get_u8()?.try_into()?;
let length = header.get_u32::<NetworkEndian>()? - 4;
self.message = (type_, length);
self.stream.consume(4 + 1);
// Wait until there is enough data in the stream. We then return without actually
// inspecting the data. This is then looked at later through the [buffer] function
let _ = self.stream.peek(length as usize).await?;
let type_ = self.read().await?;
match type_ {
Message::ErrorResponse | Message::NoticeResponse => {
@@ -84,10 +94,19 @@ impl PgStream {
continue;
}
_ => {
return Ok(type_);
Message::NotificationResponse => {
if let Some(buffer) = &mut self.notifications {
let notification = NotificationResponse::read(self.stream.buffer())?;
let _ = buffer.send(notification.into_owned()).await;
continue;
}
}
_ => {}
}
return Ok(type_);
}
}