use std::convert::TryInto; use async_std::net::{Shutdown, TcpStream}; use byteorder::NetworkEndian; use futures_core::future::BoxFuture; use crate::cache::StatementCache; use crate::connection::Connection; use crate::io::{Buf, BufStream}; use crate::postgres::protocol::{self, Decode, Encode, Message, StatementId}; use crate::postgres::PgError; use crate::url::Url; pub struct PgConnection { pub(super) stream: BufStream, // Map of query to statement id pub(super) statement_cache: StatementCache, // Next statement id pub(super) next_statement_id: u32, // Process ID of the Backend process_id: u32, // Backend-unique key to use to send a cancel query message to the server secret_key: u32, // Is there a query in progress; are we ready to continue pub(super) ready: bool, } impl PgConnection { // https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3 async fn startup(&mut self, url: Url) -> crate::Result<()> { // Defaults to postgres@.../postgres let username = url.username().unwrap_or("postgres"); let database = url.database().unwrap_or("postgres"); // See this doc for more runtime parameters // https://www.postgresql.org/docs/12/runtime-config-client.html let params = &[ ("user", username), ("database", database), // Sets the display format for date and time values, // as well as the rules for interpreting ambiguous date input values. ("DateStyle", "ISO, MDY"), // Sets the display format for interval values. ("IntervalStyle", "iso_8601"), // Sets the time zone for displaying and interpreting time stamps. ("TimeZone", "UTC"), // Adjust postgres to return percise values for floats // NOTE: This is default in postgres 12+ ("extra_float_digits", "3"), // Sets the client-side encoding (character set). ("client_encoding", "UTF-8"), ]; protocol::StartupMessage { params }.encode(self.stream.buffer_mut()); self.stream.flush().await?; while let Some(message) = self.receive().await? { match message { Message::Authentication(auth) => { match *auth { protocol::Authentication::Ok => { // Do nothing. No password is needed to continue. } protocol::Authentication::ClearTextPassword => { protocol::PasswordMessage::ClearText( url.password().unwrap_or_default(), ) .encode(self.stream.buffer_mut()); self.stream.flush().await?; } protocol::Authentication::Md5Password { salt } => { protocol::PasswordMessage::Md5 { password: url.password().unwrap_or_default(), user: username, salt, } .encode(self.stream.buffer_mut()); self.stream.flush().await?; } auth => { return Err(protocol_err!( "requires unimplemented authentication method: {:?}", auth ) .into()); } } } Message::BackendKeyData(body) => { self.process_id = body.process_id; self.secret_key = body.secret_key; } Message::ReadyForQuery(_) => { // Connection fully established and ready to receive queries. break; } message => { return Err(protocol_err!("received unexpected message: {:?}", message).into()); } } } Ok(()) } // https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.10 async fn terminate(mut self) -> crate::Result<()> { protocol::Terminate.encode(self.stream.buffer_mut()); self.stream.flush().await?; self.stream.stream.shutdown(Shutdown::Both)?; Ok(()) } // Wait and return the next message to be received from Postgres. pub(super) async fn receive(&mut self) -> crate::Result> { loop { // Read the message header (id + len) let mut header = ret_if_none!(self.stream.peek(5).await?); let id = header.get_u8()?; let len = (header.get_u32::()? - 4) as usize; // Read the message body self.stream.consume(5); let body = ret_if_none!(self.stream.peek(len).await?); let message = match id { b'N' | b'E' => Message::Response(Box::new(protocol::Response::decode(body)?)), b'D' => Message::DataRow(protocol::DataRow::decode(body)?), b'S' => { Message::ParameterStatus(Box::new(protocol::ParameterStatus::decode(body)?)) } b'Z' => Message::ReadyForQuery(protocol::ReadyForQuery::decode(body)?), b'R' => Message::Authentication(Box::new(protocol::Authentication::decode(body)?)), b'K' => Message::BackendKeyData(protocol::BackendKeyData::decode(body)?), b'C' => Message::CommandComplete(protocol::CommandComplete::decode(body)?), b'A' => Message::NotificationResponse(Box::new( protocol::NotificationResponse::decode(body)?, )), b'1' => Message::ParseComplete, b'2' => Message::BindComplete, b'3' => Message::CloseComplete, b'n' => Message::NoData, b's' => Message::PortalSuspended, b't' => Message::ParameterDescription(Box::new( protocol::ParameterDescription::decode(body)?, )), b'T' => Message::RowDescription(Box::new(protocol::RowDescription::decode(body)?)), id => { return Err(protocol_err!("received unknown message id: {:?}", id).into()); } }; self.stream.consume(len); match message { Message::ParameterStatus(_body) => { // TODO: not sure what to do with these yet } Message::Response(body) => { if body.severity.is_error() { // This is an error, stop the world and bubble as an error return Err(PgError(body).into()); } else { // This is a _warning_ // TODO: Log the warning } } message => { return Ok(Some(message)); } } } } } impl PgConnection { pub(super) async fn open(url: crate::Result) -> crate::Result { let url = url?; let stream = TcpStream::connect((url.host(), url.port(5432))).await?; let mut self_ = Self { stream: BufStream::new(stream), process_id: 0, secret_key: 0, // Important to start at 1 as 0 means "unnamed" in our protocol next_statement_id: 1, statement_cache: StatementCache::new(), ready: true, }; self_.startup(url).await?; Ok(self_) } } impl Connection for PgConnection { fn open(url: T) -> BoxFuture<'static, crate::Result> where T: TryInto, Self: Sized, { Box::pin(PgConnection::open(url.try_into())) } fn close(self) -> BoxFuture<'static, crate::Result<()>> { Box::pin(self.terminate()) } }