mirror of
https://github.com/launchbadge/sqlx.git
synced 2026-02-20 06:50:11 +00:00
143 lines
4.7 KiB
Rust
143 lines
4.7 KiB
Rust
use std::convert::TryInto;
|
|
use std::net::Shutdown;
|
|
|
|
use byteorder::NetworkEndian;
|
|
use futures_channel::mpsc::UnboundedSender;
|
|
|
|
use crate::io::{Buf, BufStream, MaybeTlsStream};
|
|
use crate::postgres::protocol::{Message, NotificationResponse, Response, Write};
|
|
use crate::postgres::PgError;
|
|
|
|
use crate::url::Url;
|
|
use futures_util::SinkExt;
|
|
|
|
pub struct PgStream {
|
|
pub(super) stream: BufStream<MaybeTlsStream>,
|
|
pub(super) notifications: Option<UnboundedSender<NotificationResponse<'static>>>,
|
|
|
|
// Most recently received message
|
|
// Is referenced by our buffered stream
|
|
// Is initialized to ReadyForQuery/0 at the start
|
|
pub(super) message: (Message, u32),
|
|
}
|
|
|
|
impl PgStream {
|
|
pub(super) async fn new(url: &Url) -> crate::Result<Self> {
|
|
let host = url.host();
|
|
let port = url.port(5432);
|
|
#[cfg(unix)]
|
|
let stream = {
|
|
let host = host
|
|
.map(|host| {
|
|
percent_encoding::percent_decode_str(host)
|
|
.decode_utf8()
|
|
.expect("percent-encoded hostname contained non-UTF-8 bytes")
|
|
})
|
|
.or_else(|| url.param("host"))
|
|
.unwrap_or("/var/run/postgresql".into());
|
|
if host.starts_with("/") {
|
|
let path = format!("{}/.s.PGSQL.{}", host, port);
|
|
MaybeTlsStream::connect_uds(&path).await?
|
|
} else {
|
|
MaybeTlsStream::connect(&host, port).await?
|
|
}
|
|
};
|
|
#[cfg(not(unix))]
|
|
let stream = MaybeTlsStream::connect(host.unwrap_or("localhost"), port).await?;
|
|
|
|
Ok(Self {
|
|
notifications: None,
|
|
stream: BufStream::new(stream),
|
|
message: (Message::ReadyForQuery, 0),
|
|
})
|
|
}
|
|
|
|
pub(super) fn shutdown(&self) -> crate::Result<()> {
|
|
Ok(self.stream.shutdown(Shutdown::Both)?)
|
|
}
|
|
|
|
#[inline]
|
|
pub(super) fn write<M>(&mut self, message: M)
|
|
where
|
|
M: Write,
|
|
{
|
|
message.write(self.stream.buffer_mut());
|
|
}
|
|
|
|
#[inline]
|
|
pub(super) async fn flush(&mut self) -> crate::Result<()> {
|
|
Ok(self.stream.flush().await?)
|
|
}
|
|
|
|
pub(super) async fn read(&mut self) -> crate::Result<Message> {
|
|
// https://www.postgresql.org/docs/12/protocol-overview.html#PROTOCOL-MESSAGE-CONCEPTS
|
|
|
|
// All communication is through a stream of messages. The first byte of a message
|
|
// identifies the message type, and the next four bytes specify the length of the rest of
|
|
// the message (this length count includes itself, but not the message-type byte).
|
|
|
|
if self.message.1 > 0 {
|
|
// If there is any data in our read buffer we need to make sure we flush that
|
|
// so reading will return the *next* message
|
|
self.stream.consume(self.message.1 as usize);
|
|
}
|
|
|
|
let mut header = self.stream.peek(4 + 1).await?;
|
|
|
|
let type_ = header.get_u8()?.try_into()?;
|
|
let length = header.get_u32::<NetworkEndian>()? - 4;
|
|
|
|
self.message = (type_, length);
|
|
self.stream.consume(4 + 1);
|
|
|
|
// Wait until there is enough data in the stream. We then return without actually
|
|
// inspecting the data. This is then looked at later through the [buffer] function
|
|
let _ = self.stream.peek(length as usize).await?;
|
|
|
|
Ok(type_)
|
|
}
|
|
|
|
pub(super) async fn receive(&mut self) -> crate::Result<Message> {
|
|
loop {
|
|
let type_ = self.read().await?;
|
|
|
|
match type_ {
|
|
Message::ErrorResponse | Message::NoticeResponse => {
|
|
let response = Response::read(self.stream.buffer())?;
|
|
|
|
if response.severity.is_error() {
|
|
// This is an error, bubble up as one immediately
|
|
return Err(crate::Error::Database(Box::new(PgError(response))));
|
|
}
|
|
|
|
// TODO: Provide some way of receiving these non-critical
|
|
// notices from postgres
|
|
continue;
|
|
}
|
|
|
|
Message::NotificationResponse => {
|
|
if let Some(buffer) = &mut self.notifications {
|
|
let notification = NotificationResponse::read(self.stream.buffer())?;
|
|
|
|
let _ = buffer.send(notification.into_owned()).await;
|
|
continue;
|
|
}
|
|
}
|
|
|
|
_ => {}
|
|
}
|
|
|
|
return Ok(type_);
|
|
}
|
|
}
|
|
|
|
/// Returns a reference to the internally buffered message.
|
|
///
|
|
/// This is the body of the message identified by the most recent call
|
|
/// to `read`.
|
|
#[inline]
|
|
pub(super) fn buffer(&self) -> &[u8] {
|
|
&self.stream.buffer()[..(self.message.1 as usize)]
|
|
}
|
|
}
|