use std::convert::TryInto; use std::net::Shutdown; use byteorder::NetworkEndian; use futures_channel::mpsc::UnboundedSender; use crate::io::{Buf, BufStream, MaybeTlsStream}; use crate::postgres::protocol::{Message, NotificationResponse, Response, Write}; use crate::postgres::PgError; use crate::url::Url; use futures_util::SinkExt; pub struct PgStream { pub(super) stream: BufStream, pub(super) notifications: Option>>, // Most recently received message // Is referenced by our buffered stream // Is initialized to ReadyForQuery/0 at the start pub(super) message: (Message, u32), } impl PgStream { pub(super) async fn new(url: &Url) -> crate::Result { let host = url.host(); let port = url.port(5432); #[cfg(unix)] let stream = { let host = host .map(|host| { percent_encoding::percent_decode_str(host) .decode_utf8() .expect("percent-encoded hostname contained non-UTF-8 bytes") }) .or_else(|| url.param("host")) .unwrap_or("/var/run/postgresql".into()); if host.starts_with("/") { let path = format!("{}/.s.PGSQL.{}", host, port); MaybeTlsStream::connect_uds(&path).await? } else { MaybeTlsStream::connect(&host, port).await? } }; #[cfg(not(unix))] let stream = MaybeTlsStream::connect(host.unwrap_or("localhost"), port).await?; Ok(Self { notifications: None, stream: BufStream::new(stream), message: (Message::ReadyForQuery, 0), }) } pub(super) fn shutdown(&self) -> crate::Result<()> { Ok(self.stream.shutdown(Shutdown::Both)?) } #[inline] pub(super) fn write(&mut self, message: M) where M: Write, { message.write(self.stream.buffer_mut()); } #[inline] pub(super) async fn flush(&mut self) -> crate::Result<()> { Ok(self.stream.flush().await?) } pub(super) async fn read(&mut self) -> crate::Result { // https://www.postgresql.org/docs/12/protocol-overview.html#PROTOCOL-MESSAGE-CONCEPTS // All communication is through a stream of messages. The first byte of a message // identifies the message type, and the next four bytes specify the length of the rest of // the message (this length count includes itself, but not the message-type byte). if self.message.1 > 0 { // If there is any data in our read buffer we need to make sure we flush that // so reading will return the *next* message self.stream.consume(self.message.1 as usize); } let mut header = self.stream.peek(4 + 1).await?; let type_ = header.get_u8()?.try_into()?; let length = header.get_u32::()? - 4; self.message = (type_, length); self.stream.consume(4 + 1); // Wait until there is enough data in the stream. We then return without actually // inspecting the data. This is then looked at later through the [buffer] function let _ = self.stream.peek(length as usize).await?; Ok(type_) } pub(super) async fn receive(&mut self) -> crate::Result { loop { let type_ = self.read().await?; match type_ { Message::ErrorResponse | Message::NoticeResponse => { let response = Response::read(self.stream.buffer())?; if response.severity.is_error() { // This is an error, bubble up as one immediately return Err(crate::Error::Database(Box::new(PgError(response)))); } // TODO: Provide some way of receiving these non-critical // notices from postgres continue; } Message::NotificationResponse => { if let Some(buffer) = &mut self.notifications { let notification = NotificationResponse::read(self.stream.buffer())?; let _ = buffer.send(notification.into_owned()).await; continue; } } _ => {} } return Ok(type_); } } /// Returns a reference to the internally buffered message. /// /// This is the body of the message identified by the most recent call /// to `read`. #[inline] pub(super) fn buffer(&self) -> &[u8] { &self.stream.buffer()[..(self.message.1 as usize)] } }