sqlx/sqlx-core/src/postgres/connection.rs
2019-12-27 21:45:31 -08:00

228 lines
8.1 KiB
Rust

use std::convert::TryInto;
use async_std::net::{Shutdown, TcpStream};
use byteorder::NetworkEndian;
use futures_core::future::BoxFuture;
use crate::cache::StatementCache;
use crate::connection::Connection;
use crate::io::{Buf, BufStream};
use crate::postgres::protocol::{self, Decode, Encode, Message, StatementId};
use crate::postgres::PgError;
use crate::url::Url;
/// A client connection to a PostgreSQL server over a buffered TCP stream.
pub struct PgConnection {
/// Buffered TCP stream used for all protocol traffic with the server.
pub(super) stream: BufStream<TcpStream>,
/// Map of query to statement id.
pub(super) statement_cache: StatementCache<StatementId>,
/// Next statement id to hand out; id 0 denotes the unnamed statement.
pub(super) next_statement_id: u32,
/// Process ID of the backend, as reported by `BackendKeyData` during startup.
process_id: u32,
/// Backend-unique key (from `BackendKeyData`) to use when sending a cancel query
/// message to the server.
secret_key: u32,
/// Is there a query in progress; are we ready to continue.
pub(super) ready: bool,
}
impl PgConnection {
    /// Performs the startup handshake on a freshly opened connection.
    ///
    /// Sends a `StartupMessage` built from the user and database in `url`
    /// (both defaulting to `"postgres"`) plus a fixed set of session parameters.
    /// It then answers cleartext or MD5 password requests with the password from
    /// `url` (empty when none is given) and records the backend key data.
    /// Returns once the server sends `ReadyForQuery`.
    ///
    /// # Errors
    ///
    /// Fails if the server requests an authentication method that is not
    /// implemented, or sends a message not expected during startup. It also fails
    /// if `receive` returns an error, or if `receive` stops yielding messages
    /// before `ReadyForQuery` arrives.
    ///
    /// <https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3>
    async fn startup(&mut self, url: Url) -> crate::Result<()> {
        // Defaults to postgres@.../postgres
        let username = url.username().unwrap_or("postgres");
        let database = url.database().unwrap_or("postgres");

        // See this doc for more runtime parameters
        // https://www.postgresql.org/docs/12/runtime-config-client.html
        let params = &[
            ("user", username),
            ("database", database),
            // Sets the display format for date and time values,
            // as well as the rules for interpreting ambiguous date input values.
            ("DateStyle", "ISO, MDY"),
            // Sets the display format for interval values.
            ("IntervalStyle", "iso_8601"),
            // Sets the time zone for displaying and interpreting time stamps.
            ("TimeZone", "UTC"),
            // Adjust postgres to return precise values for floats
            // NOTE: This is default in postgres 12+
            ("extra_float_digits", "3"),
            // Sets the client-side encoding (character set).
            ("client_encoding", "UTF-8"),
        ];

        protocol::StartupMessage { params }.encode(self.stream.buffer_mut());
        self.stream.flush().await?;

        while let Some(message) = self.receive().await? {
            match message {
                Message::Authentication(auth) => {
                    match *auth {
                        protocol::Authentication::Ok => {
                            // Do nothing. No password is needed to continue.
                        }

                        protocol::Authentication::ClearTextPassword => {
                            protocol::PasswordMessage::ClearText(
                                url.password().unwrap_or_default(),
                            )
                            .encode(self.stream.buffer_mut());

                            self.stream.flush().await?;
                        }

                        protocol::Authentication::Md5Password { salt } => {
                            protocol::PasswordMessage::Md5 {
                                password: url.password().unwrap_or_default(),
                                user: username,
                                salt,
                            }
                            .encode(self.stream.buffer_mut());

                            self.stream.flush().await?;
                        }

                        auth => {
                            return Err(protocol_err!(
                                "requires unimplemented authentication method: {:?}",
                                auth
                            )
                            .into());
                        }
                    }
                }

                Message::BackendKeyData(body) => {
                    self.process_id = body.process_id;
                    self.secret_key = body.secret_key;
                }

                Message::ReadyForQuery(_) => {
                    // Connection fully established and ready to receive queries.
                    return Ok(());
                }

                message => {
                    return Err(protocol_err!("received unexpected message: {:?}", message).into());
                }
            }
        }

        // `receive` stopped yielding messages before the server sent
        // `ReadyForQuery`, so the handshake never completed. Reporting success
        // here would hand the caller a connection that was never established.
        Err(protocol_err!("connection closed before startup completed").into())
    }

    /// Sends a `Terminate` message, flushes it, and shuts down both halves of
    /// the underlying TCP socket. Consumes the connection.
    ///
    /// <https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.10>
    async fn terminate(mut self) -> crate::Result<()> {
        protocol::Terminate.encode(self.stream.buffer_mut());

        self.stream.flush().await?;
        self.stream.stream.shutdown(Shutdown::Both)?;

        Ok(())
    }

    /// Waits for and returns the next message received from Postgres.
    ///
    /// `ParameterStatus` messages are consumed and skipped. `Response` messages
    /// (ids `N`/`E`) are handled here too: one whose severity is an error is
    /// returned as `Err(PgError)`, and any other is skipped. Every other
    /// recognised message is returned to the caller.
    ///
    /// `Ok(None)` comes from `ret_if_none!` when the stream cannot provide the
    /// requested bytes (presumably end of stream; confirm against
    /// `BufStream::peek`).
    ///
    /// # Errors
    ///
    /// Fails on I/O or decode errors, on an unknown message id, on a message
    /// whose length field is smaller than 4, and on an error `Response`.
    pub(super) async fn receive(&mut self) -> crate::Result<Option<Message>> {
        loop {
            // Read the message header: a 1-byte id followed by a 4-byte length.
            let mut header = ret_if_none!(self.stream.peek(5).await?);

            let id = header.get_u8()?;

            // The length field counts its own 4 bytes but not the id byte. It
            // comes straight off the wire, so reject values below 4 instead of
            // letting the subtraction underflow. Underflow panics in debug
            // builds and yields a huge bogus body length in release builds.
            let raw_len = header.get_u32::<NetworkEndian>()?;
            let len = match raw_len.checked_sub(4) {
                Some(body_len) => body_len as usize,
                None => {
                    return Err(protocol_err!(
                        "received message with invalid length: {}",
                        raw_len
                    )
                    .into());
                }
            };

            // Read the message body
            self.stream.consume(5);
            let body = ret_if_none!(self.stream.peek(len).await?);

            let message = match id {
                b'N' | b'E' => Message::Response(Box::new(protocol::Response::decode(body)?)),
                b'D' => Message::DataRow(protocol::DataRow::decode(body)?),
                b'S' => {
                    Message::ParameterStatus(Box::new(protocol::ParameterStatus::decode(body)?))
                }
                b'Z' => Message::ReadyForQuery(protocol::ReadyForQuery::decode(body)?),
                b'R' => Message::Authentication(Box::new(protocol::Authentication::decode(body)?)),
                b'K' => Message::BackendKeyData(protocol::BackendKeyData::decode(body)?),
                b'C' => Message::CommandComplete(protocol::CommandComplete::decode(body)?),
                b'A' => Message::NotificationResponse(Box::new(
                    protocol::NotificationResponse::decode(body)?,
                )),
                b'1' => Message::ParseComplete,
                b'2' => Message::BindComplete,
                b'3' => Message::CloseComplete,
                b'n' => Message::NoData,
                b's' => Message::PortalSuspended,
                b't' => Message::ParameterDescription(Box::new(
                    protocol::ParameterDescription::decode(body)?,
                )),
                b'T' => Message::RowDescription(Box::new(protocol::RowDescription::decode(body)?)),

                id => {
                    return Err(protocol_err!("received unknown message id: {:?}", id).into());
                }
            };

            self.stream.consume(len);

            match message {
                Message::ParameterStatus(_body) => {
                    // TODO: not sure what to do with these yet
                }

                Message::Response(body) => {
                    if body.severity.is_error() {
                        // This is an error, stop the world and bubble as an error
                        return Err(PgError(body).into());
                    } else {
                        // This is a _warning_
                        // TODO: Log the warning
                    }
                }

                message => {
                    return Ok(Some(message));
                }
            }
        }
    }
}
impl PgConnection {
    /// Connects over TCP to the host in `url` and performs the startup
    /// handshake on the new connection.
    ///
    /// `url` arrives as a `Result`, so a failed URL conversion becomes the error
    /// of this call. `5432` is passed to `Url::port` as the default port.
    pub(super) async fn open(url: crate::Result<Url>) -> crate::Result<Self> {
        let url = url?;

        let mut conn = Self {
            stream: BufStream::new(TcpStream::connect((url.host(), url.port(5432))).await?),
            // Populated from BackendKeyData during startup.
            process_id: 0,
            secret_key: 0,
            // Id 0 is the "unnamed" statement in our protocol, so named ids begin at 1.
            next_statement_id: 1,
            statement_cache: StatementCache::new(),
            ready: true,
        };

        conn.startup(url).await.map(|()| conn)
    }
}
impl Connection for PgConnection {
    /// Converts `url` into a [`Url`] and opens a Postgres connection with it.
    /// A conversion failure is reported when the returned future is awaited.
    fn open<T>(url: T) -> BoxFuture<'static, crate::Result<Self>>
    where
        T: TryInto<Url, Error = crate::Error>,
        Self: Sized,
    {
        let parsed: crate::Result<Url> = url.try_into();
        Box::pin(PgConnection::open(parsed))
    }

    /// Consumes the connection and returns a boxed future that runs its
    /// `terminate` sequence.
    fn close(self) -> BoxFuture<'static, crate::Result<()>> {
        let shutdown = self.terminate();
        Box::pin(shutdown)
    }
}