diff --git a/examples/quickstart/postgres+async-std/src/main.rs b/examples/quickstart/postgres+async-std/src/main.rs index 652f10ec..ffd8c1d1 100644 --- a/examples/quickstart/postgres+async-std/src/main.rs +++ b/examples/quickstart/postgres+async-std/src/main.rs @@ -1,5 +1,5 @@ use sqlx::postgres::{PgConnectOptions, PgConnection}; -use sqlx::{Connection, Close, ConnectOptions}; +use sqlx::{Close, ConnectOptions, Connection, Executor}; #[async_std::main] async fn main() -> anyhow::Result<()> { @@ -10,7 +10,8 @@ async fn main() -> anyhow::Result<()> { // set a password (perhaps from somewhere else than the rest of the URL) .password("password") // connect to the database (non-blocking) - .connect().await?; + .connect() + .await?; // the following are equivalent to the above: @@ -23,9 +24,11 @@ async fn main() -> anyhow::Result<()> { // when writing a *type*, Rust allows default type parameters // as opposed to writing a *path* where it does not (yet) + let res = conn.execute("SELECT 1").await?; + // ping, this makes sure the server is still there // hopefully it is – we did just connect to it - conn.ping().await?; + // conn.ping().await?; // close the connection explicitly // this kindly informs the database server that we'll be terminating diff --git a/sqlx-postgres/src/column.rs b/sqlx-postgres/src/column.rs index 1edf267c..fc192c85 100644 --- a/sqlx-postgres/src/column.rs +++ b/sqlx-postgres/src/column.rs @@ -1,31 +1,88 @@ +use std::num::{NonZeroI16, NonZeroI32}; + use bytestring::ByteString; use sqlx_core::{Column, Database}; -use crate::{PgTypeInfo, Postgres}; - -// TODO: inherent methods from +use crate::protocol::backend::Field; +use crate::{PgRawValueFormat, PgTypeId, PgTypeInfo, Postgres}; /// Represents a column from a query in Postgres. #[allow(clippy::module_name_repetitions)] #[derive(Debug, Clone)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] pub struct PgColumn { + /// The index of the column in the row where it is from. index: usize, + + /// The name of the column. name: ByteString, - type_info: PgTypeInfo, + + /// The type information for the column's data type. + pub(crate) type_info: PgTypeInfo, + + /// If the column can be identified as a column of a specific table, the object ID of the table. + #[cfg_attr(feature = "offline", serde(skip))] + pub(crate) relation_id: Option, + + /// If the column can be identified as a column of a specific table, the attribute number of the column. + #[cfg_attr(feature = "offline", serde(skip))] + pub(crate) relation_attribute_no: Option, + + /// The type size (see pg_type.typlen). Note that negative values denote variable-width types. + pub(crate) type_size: i16, + + /// The type modifier (see pg_attribute.atttypmod). The meaning of the modifier is type-specific. + pub(crate) type_modifier: i32, + + /// The format code being used for the column. + pub(crate) format: PgRawValueFormat, +} + +impl PgColumn { + pub(crate) fn from_field(index: usize, field: Field) -> Self { + Self { + index, + name: field.name, + type_info: PgTypeInfo(PgTypeId::Oid(field.type_id)), + relation_id: field.relation_id, + relation_attribute_no: field.relation_attribute_no, + type_modifier: field.type_modifier, + type_size: field.type_size, + format: field.format, + } + } + + /// Returns the name of the column. + pub fn name(&self) -> &str { + &self.name + } + + /// Returns the (zero-based) position of the column. + pub const fn index(&self) -> usize { + self.index + } + + /// Returns type information of the column. + pub fn type_info(&self) -> &PgTypeInfo { + &self.type_info + } } impl Column for PgColumn { type Database = Postgres; + #[inline] fn name(&self) -> &str { - &self.name + self.name() } + #[inline] fn index(&self) -> usize { - self.index + self.index() } + #[inline] fn type_info(&self) -> &PgTypeInfo { - &self.type_info + self.type_info() } } diff --git a/sqlx-postgres/src/connection.rs b/sqlx-postgres/src/connection.rs index 35fab1a9..7c75f959 100644 --- a/sqlx-postgres/src/connection.rs +++ b/sqlx-postgres/src/connection.rs @@ -5,16 +5,30 @@ use futures_util::future::{BoxFuture, FutureExt, TryFutureExt}; use sqlx_core::net::Stream as NetStream; use sqlx_core::{Close, Connect, Connection, Runtime}; +use crate::protocol::backend::TransactionStatus; use crate::stream::PgStream; use crate::{PgConnectOptions, Postgres}; +#[macro_use] +mod flush; + mod connect; +mod executor; /// A single connection (also known as a session) to a /// PostgreSQL database server. pub struct PgConnection { stream: PgStream, + // number of commands that have been executed + // and have yet to see their completion acknowledged + // in other words, the number of messages + // we expect before the stream is clear + pending_ready_for_query_count: usize, + + // current transaction status + transaction_status: TransactionStatus, + // process id of this backend // can be used to send cancel requests #[allow(dead_code)] @@ -37,7 +51,13 @@ where impl PgConnection { pub(crate) fn new(stream: NetStream) -> Self { - Self { stream: PgStream::new(stream), process_id: 0, secret_key: 0 } + Self { + stream: PgStream::new(stream), + process_id: 0, + secret_key: 0, + transaction_status: TransactionStatus::Idle, + pending_ready_for_query_count: 0, + } } } diff --git a/sqlx-postgres/src/connection/connect.rs b/sqlx-postgres/src/connection/connect.rs index 6f817c0e..172c2fb7 100644 --- a/sqlx-postgres/src/connection/connect.rs +++ b/sqlx-postgres/src/connection/connect.rs @@ -15,7 +15,7 @@ use sqlx_core::net::Stream as NetStream; use sqlx_core::{Error, Result, Runtime}; -use crate::protocol::backend::{Authentication, BackendMessage, BackendMessageType}; +use crate::protocol::backend::{Authentication, BackendMessage, BackendMessageType, KeyData}; use crate::protocol::frontend::{Password, PasswordMd5, Startup}; use crate::{PgClientError, PgConnectOptions, PgConnection}; @@ -47,7 +47,7 @@ impl PgConnection { match message.ty { BackendMessageType::Authentication => match message.deserialize()? { Authentication::Ok => { - return Ok(true); + // nothing more to do with authentication } Authentication::Md5Password(data) => { @@ -68,11 +68,26 @@ impl PgConnection { Authentication::SaslFinal(_) => todo!("sasl final"), }, + BackendMessageType::ReadyForQuery => { + self.handle_ready_for_query(message.deserialize()?); + + // fully connected + return Ok(true); + } + + BackendMessageType::BackendKeyData => { + let key_data: KeyData = message.deserialize()?; + + self.process_id = key_data.process_id; + self.secret_key = key_data.secret_key; + } + ty => { - return Err(Error::client(PgClientError::UnexpectedMessageType { + return Err(PgClientError::UnexpectedMessageType { ty: ty as u8, context: "starting up", - })); + } + .into()); } } @@ -100,13 +115,15 @@ macro_rules! impl_connect { // to begin a session, a frontend should send a startup message // this is built up of various startup parameters that control the connection self_.write_startup_message($options)?; + self_.pending_ready_for_query_count += 1; // the server then uses this information and the contents of // its configuration files (such as pg_hba.conf) to determine whether the connection is // provisionally acceptable, and what additional // authentication is required (if any). loop { - let message = read_message!($(@$blocking)? self_.stream); + let message = read_message!($(@$blocking)? self_.stream)?; + if self_.handle_startup_response($options, message)? { // complete, successful authentication break; diff --git a/sqlx-postgres/src/connection/executor.rs b/sqlx-postgres/src/connection/executor.rs new file mode 100644 index 00000000..c61ce6bb --- /dev/null +++ b/sqlx-postgres/src/connection/executor.rs @@ -0,0 +1,107 @@ +#[cfg(feature = "async")] +use futures_util::{future::BoxFuture, FutureExt}; +use sqlx_core::{Execute, Executor, Result, Runtime}; + +use crate::protocol::backend::ReadyForQuery; +use crate::{PgConnection, PgQueryResult, PgRow, Postgres}; + +#[macro_use] +mod raw_prepare; + +#[macro_use] +mod raw_query; + +mod execute; +mod fetch_all; +mod fetch_optional; + +impl PgConnection { + pub(crate) fn handle_ready_for_query(&mut self, rq: ReadyForQuery) { + self.transaction_status = rq.transaction_status; + + debug_assert!(self.pending_ready_for_query_count > 0); + self.pending_ready_for_query_count -= 1; + } +} + +impl Executor for PgConnection { + type Database = Postgres; + + #[cfg(feature = "async")] + #[inline] + fn execute<'x, 'e, 'q, 'v, X>(&'e mut self, query: X) -> BoxFuture<'x, Result> + where + Rt: sqlx_core::Async, + X: 'x + Execute<'q, 'v, Postgres>, + 'e: 'x, + 'q: 'x, + 'v: 'x, + { + Box::pin(self.execute_async(query)) + } + + #[cfg(feature = "async")] + #[inline] + fn fetch_all<'x, 'e, 'q, 'v, X>(&'e mut self, query: X) -> BoxFuture<'x, Result>> + where + Rt: sqlx_core::Async, + X: 'x + Execute<'q, 'v, Postgres>, + 'e: 'x, + 'q: 'x, + 'v: 'x, + { + Box::pin(self.fetch_all_async(query)) + } + + #[cfg(feature = "async")] + #[inline] + fn fetch_optional<'x, 'e, 'q, 'v, X>( + &'e mut self, + query: X, + ) -> BoxFuture<'x, Result>> + where + Rt: sqlx_core::Async, + X: 'x + Execute<'q, 'v, Postgres>, + 'e: 'x, + 'q: 'x, + 'v: 'x, + { + Box::pin(self.fetch_optional_async(query)) + } +} + +#[cfg(feature = "blocking")] +impl sqlx_core::blocking::Executor for PgConnection { + #[inline] + fn execute<'x, 'e, 'q, 'v, X>(&'e mut self, query: X) -> Result + where + X: 'x + Execute<'q, 'v, Postgres>, + 'e: 'x, + 'q: 'x, + 'v: 'x, + { + self.execute_blocking(query) + } + + #[inline] + fn fetch_all<'x, 'e, 'q, 'v, X>(&'e mut self, query: X) -> Result> + where + X: 'x + Execute<'q, 'v, Postgres>, + 'e: 'x, + 'q: 'x, + 'v: 'x, + { + self.fetch_all_blocking(query) + } + + #[inline] + fn fetch_optional<'x, 'e, 'q, 'v, X>(&'e mut self, query: X) -> Result> + where + X: 'x + Execute<'q, 'v, Postgres>, + 'e: 'x, + 'q: 'x, + 'v: 'x, + { + self.fetch_optional_blocking(query) + } +} diff --git a/sqlx-postgres/src/connection/executor/execute.rs b/sqlx-postgres/src/connection/executor/execute.rs new file mode 100644 index 00000000..75b36302 --- /dev/null +++ b/sqlx-postgres/src/connection/executor/execute.rs @@ -0,0 +1,81 @@ +use sqlx_core::{Error, Execute, Result, Runtime}; + +use crate::protocol::backend::{BackendMessage, BackendMessageType}; +use crate::{PgClientError, PgConnection, PgQueryResult, Postgres}; + +impl PgConnection { + fn handle_message_in_execute( + &mut self, + message: BackendMessage, + result: &mut PgQueryResult, + ) -> Result { + match message.ty { + // ignore rows received or metadata about them + // TODO: should we log a warning? its wasteful to use `execute` on a query + // that does return rows + BackendMessageType::DataRow | BackendMessageType::RowDescription => {} + + BackendMessageType::CommandComplete => { + // one statement has finished + result.extend(Some(PgQueryResult::parse(message.contents)?)); + } + + BackendMessageType::ReadyForQuery => { + self.handle_ready_for_query(message.deserialize()?); + + // all statements are finished + return Ok(true); + } + + ty => { + return Err(PgClientError::UnexpectedMessageType { + ty: ty as u8, + context: "executing a query [execute]", + } + .into()); + } + } + + Ok(false) + } +} + +macro_rules! impl_execute { + ($(@$blocking:ident)? $self:ident, $query:ident) => {{ + raw_query!($(@$blocking)? $self, $query); + + let mut result = PgQueryResult::default(); + + loop { + let message = read_message!($(@$blocking)? $self.stream)?; + + if $self.handle_message_in_execute(message, &mut result)? { + break; + } + } + + Ok(result) + }}; +} + +impl PgConnection { + #[cfg(feature = "async")] + pub(super) async fn execute_async<'q, 'a, E>(&mut self, query: E) -> Result + where + Rt: sqlx_core::Async, + E: Execute<'q, 'a, Postgres>, + { + flush!(self); + impl_execute!(self, query) + } + + #[cfg(feature = "blocking")] + pub(super) fn execute_blocking<'q, 'a, E>(&mut self, query: E) -> Result + where + Rt: sqlx_core::blocking::Runtime, + E: Execute<'q, 'a, Postgres>, + { + flush!(self); + impl_execute!(@blocking self, query) + } +} diff --git a/sqlx-postgres/src/connection/executor/fetch_all.rs b/sqlx-postgres/src/connection/executor/fetch_all.rs new file mode 100644 index 00000000..09ca9a64 --- /dev/null +++ b/sqlx-postgres/src/connection/executor/fetch_all.rs @@ -0,0 +1,88 @@ +use std::sync::Arc; + +use sqlx_core::io::Deserialize; +use sqlx_core::{Error, Execute, Result, Runtime}; + +use crate::protocol::backend::{BackendMessage, BackendMessageType, ReadyForQuery, RowDescription}; +use crate::{PgClientError, PgColumn, PgConnection, PgQueryResult, PgRow, Postgres}; + +impl PgConnection { + fn handle_message_in_fetch_all( + &mut self, + message: BackendMessage, + rows: &mut Vec, + columns: &mut Option>, + ) -> Result { + match message.ty { + BackendMessageType::DataRow => { + rows.push(PgRow::new(message.deserialize()?, &columns)); + } + + BackendMessageType::RowDescription => { + *columns = Some(message.deserialize::()?.columns.into()); + } + + BackendMessageType::CommandComplete => { + // one statement has finished + } + + BackendMessageType::ReadyForQuery => { + self.handle_ready_for_query(message.deserialize()?); + + // all statements in this query have finished + return Ok(true); + } + + ty => { + return Err(PgClientError::UnexpectedMessageType { + ty: ty as u8, + context: "executing a query [fetch_all]", + } + .into()); + } + } + + Ok(false) + } +} + +macro_rules! impl_fetch_all { + ($(@$blocking:ident)? $self:ident, $query:ident) => {{ + raw_query!($(@$blocking)? $self, $query); + + let mut rows = Vec::with_capacity(10); + let mut columns = None; + + loop { + let message = read_message!($(@$blocking)? $self.stream)?; + + if $self.handle_message_in_fetch_all(message, &mut rows, &mut columns)? { + break; + } + } + + Ok(rows) + }}; +} + +impl PgConnection { + #[cfg(feature = "async")] + pub(super) async fn fetch_all_async<'q, 'a, E>(&mut self, query: E) -> Result> + where + Rt: sqlx_core::Async, + E: Execute<'q, 'a, Postgres>, + { + flush!(self); + impl_fetch_all!(self, query) + } + + #[cfg(feature = "blocking")] + pub(super) fn fetch_all_blocking<'q, 'a, E>(&mut self, query: E) -> Result> + where + Rt: sqlx_core::blocking::Runtime, + E: Execute<'q, 'a, Postgres>, + { + flush!(self); + impl_fetch_all!(@blocking self, query) + } +} diff --git a/sqlx-postgres/src/connection/executor/fetch_optional.rs b/sqlx-postgres/src/connection/executor/fetch_optional.rs new file mode 100644 index 00000000..eadbc7bf --- /dev/null +++ b/sqlx-postgres/src/connection/executor/fetch_optional.rs @@ -0,0 +1,96 @@ +use std::sync::Arc; + +use sqlx_core::io::Deserialize; +use sqlx_core::{Error, Execute, Result, Runtime}; + +use crate::protocol::backend::{BackendMessage, BackendMessageType, ReadyForQuery, RowDescription}; +use crate::{PgClientError, PgColumn, PgConnection, PgQueryResult, PgRow, Postgres}; + +impl PgConnection { + fn handle_message_in_fetch_optional( + &mut self, + message: BackendMessage, + first_row: &mut Option, + columns: &mut Option>, + ) -> Result { + match message.ty { + BackendMessageType::DataRow => { + debug_assert!(first_row.is_none()); + + *first_row = Some(PgRow::new(message.deserialize()?, &columns)); + + // exit early, we have 1 row + return Ok(true); + } + + BackendMessageType::RowDescription => { + *columns = Some(message.deserialize::()?.columns.into()); + } + + BackendMessageType::CommandComplete => { + // one statement has finished + } + + BackendMessageType::ReadyForQuery => { + self.handle_ready_for_query(message.deserialize()?); + + // all statements in this query have finished + return Ok(true); + } + + ty => { + return Err(PgClientError::UnexpectedMessageType { + ty: ty as u8, + context: "executing a query [fetch_optional]", + } + .into()); + } + } + + Ok(false) + } +} + +macro_rules! impl_fetch_optional { + ($(@$blocking:ident)? $self:ident, $query:ident) => {{ + raw_query!($(@$blocking)? $self, $query); + + let mut first_row = None; + let mut columns = None; + + loop { + let message = read_message!($(@$blocking)? $self.stream)?; + + if $self.handle_message_in_fetch_optional(message, &mut first_row, &mut columns)? { + break; + } + } + + Ok(first_row) + }}; +} + +impl PgConnection { + #[cfg(feature = "async")] + pub(super) async fn fetch_optional_async<'q, 'a, E>( + &mut self, + query: E, + ) -> Result> + where + Rt: sqlx_core::Async, + E: Execute<'q, 'a, Postgres>, + { + flush!(self); + impl_fetch_optional!(self, query) + } + + #[cfg(feature = "blocking")] + pub(super) fn fetch_optional_blocking<'q, 'a, E>(&mut self, query: E) -> Result> + where + Rt: sqlx_core::blocking::Runtime, + E: Execute<'q, 'a, Postgres>, + { + flush!(self); + impl_fetch_optional!(@blocking self, query) + } +} diff --git a/sqlx-postgres/src/connection/executor/raw_prepare.rs b/sqlx-postgres/src/connection/executor/raw_prepare.rs new file mode 100644 index 00000000..e69de29b diff --git a/sqlx-postgres/src/connection/executor/raw_query.rs b/sqlx-postgres/src/connection/executor/raw_query.rs new file mode 100644 index 00000000..57160e55 --- /dev/null +++ b/sqlx-postgres/src/connection/executor/raw_query.rs @@ -0,0 +1,59 @@ +use sqlx_core::{Execute, Result, Runtime}; + +use crate::protocol::frontend::Query; +use crate::{PgConnection, PgRawValueFormat, Postgres}; + +macro_rules! impl_raw_query { + ($(@$blocking:ident)? $self:ident, $query:ident) => {{ + let format = if let Some(arguments) = $query.arguments() { + todo!("prepared query for postgres") + } else { + // directly execute the query as an unprepared, simple query + $self.stream.write_message(&Query { sql: $query.sql() })?; + + // unprepared queries use the TEXT format + // this is a significant waste of bandwidth for large result sets + PgRawValueFormat::Text + }; + + // as we have written a SQL command of some kind to the stream + // we now expect there to be an eventual ReadyForQuery + // if for some reason the future for one of the execution methods is dropped + // half-way through, we need to flush the stream until the ReadyForQuery point + $self.pending_ready_for_query_count += 1; + + Ok(format) + }}; +} + +impl PgConnection { + #[cfg(feature = "async")] + pub(super) async fn raw_query_async<'q, 'a, E>(&mut self, query: E) -> Result + where + Rt: sqlx_core::Async, + E: Execute<'q, 'a, Postgres>, + { + flush!(self); + impl_raw_query!(self, query) + } + + #[cfg(feature = "blocking")] + pub(super) fn raw_query_blocking<'q, 'a, E>(&mut self, query: E) -> Result + where + Rt: sqlx_core::blocking::Runtime, + E: Execute<'q, 'a, Postgres>, + { + flush!(@blocking self); + impl_raw_query!(@blocking self, query) + } +} + +macro_rules! raw_query { + (@blocking $self:ident, $sql:expr) => { + $self.raw_query_blocking($sql)? + }; + + ($self:ident, $sql:expr) => { + $self.raw_query_async($sql).await? + }; +} diff --git a/sqlx-postgres/src/connection/flush.rs b/sqlx-postgres/src/connection/flush.rs new file mode 100644 index 00000000..4d1e9f85 --- /dev/null +++ b/sqlx-postgres/src/connection/flush.rs @@ -0,0 +1,78 @@ +use crate::protocol::backend::{BackendMessage, BackendMessageType}; +use crate::PgConnection; +use sqlx_core::{Error, Result, Runtime}; + +impl PgConnection { + fn handle_message_in_flush(&mut self, message: BackendMessage) -> Result { + match message.ty { + BackendMessageType::ReadyForQuery => { + self.handle_ready_for_query(message.deserialize()?); + + return Ok(true); + } + + _ => {} + } + + Ok(false) + } +} + +macro_rules! impl_flush { + ($(@$blocking:ident)? $self:ident) => {{ + while $self.pending_ready_for_query_count > 0 { + loop { + let message = read_message!($(@$blocking)? $self.stream); + + match message { + Ok(message) => { + if $self.handle_message_in_flush(message)? { + break; + } + } + + Err(error) => { + if matches!(error, Error::Database(_)) { + // log database errors instead of failing on them + // during a flush + log::error!("{}", error); + } else { + return Err(error); + } + } + } + + } + } + + Ok(()) + }}; +} + +impl PgConnection { + #[cfg(feature = "async")] + pub(super) async fn flush_async(&mut self) -> Result<()> + where + Rt: sqlx_core::Async, + { + impl_flush!(self) + } + + #[cfg(feature = "blocking")] + pub(super) fn flush_blocking(&mut self) -> Result<()> + where + Rt: sqlx_core::blocking::Runtime, + { + impl_flush!(@blocking self) + } +} + +macro_rules! flush { + (@blocking $self:ident) => { + $self.flush_blocking()? + }; + + ($self:ident) => { + $self.flush_async().await? + }; +} diff --git a/sqlx-postgres/src/error/client.rs b/sqlx-postgres/src/error/client.rs index 3a9229a3..cdf02cba 100644 --- a/sqlx-postgres/src/error/client.rs +++ b/sqlx-postgres/src/error/client.rs @@ -2,7 +2,7 @@ use std::error::Error as StdError; use std::fmt::{self, Display, Formatter}; use std::str::Utf8Error; -use sqlx_core::ClientError; +use sqlx_core::{ClientError, Error}; use crate::protocol::backend::BackendMessageType; @@ -15,6 +15,7 @@ pub enum PgClientError { UnknownAuthenticationMethod(u32), UnknownMessageType(u8), UnknownTransactionStatus(u8), + UnknownValueFormat(i16), UnexpectedMessageType { ty: u8, context: &'static str }, } @@ -31,6 +32,10 @@ impl Display for PgClientError { write!(f, "in ReadyForQuery, unknown transaction status: {}", status) } + Self::UnknownValueFormat(format) => { + write!(f, "unknown value format: {}", format) + } + Self::UnknownMessageType(ty) => { write!(f, "unknown protocol message type: '{}' ({})", *ty as char, *ty) } @@ -45,3 +50,9 @@ impl Display for PgClientError { impl StdError for PgClientError {} impl ClientError for PgClientError {} + +impl From for Error { + fn from(err: PgClientError) -> Error { + Error::client(err) + } +} diff --git a/sqlx-postgres/src/protocol/backend.rs b/sqlx-postgres/src/protocol/backend.rs index 2881a597..0e7c533d 100644 --- a/sqlx-postgres/src/protocol/backend.rs +++ b/sqlx-postgres/src/protocol/backend.rs @@ -14,6 +14,6 @@ pub(crate) use key_data::KeyData; pub(crate) use message::{BackendMessage, BackendMessageType}; pub(crate) use parameter_description::ParameterDescription; pub(crate) use parameter_status::ParameterStatus; -pub(crate) use ready_for_query::ReadyForQuery; -pub(crate) use row_description::RowDescription; +pub(crate) use ready_for_query::{ReadyForQuery, TransactionStatus}; +pub(crate) use row_description::{Field, RowDescription}; pub(crate) use sasl::{AuthenticationSasl, AuthenticationSaslContinue, AuthenticationSaslFinal}; diff --git a/sqlx-postgres/src/protocol/backend/auth.rs b/sqlx-postgres/src/protocol/backend/auth.rs index 95f2c095..dcc17ece 100644 --- a/sqlx-postgres/src/protocol/backend/auth.rs +++ b/sqlx-postgres/src/protocol/backend/auth.rs @@ -50,7 +50,7 @@ impl Deserialize<'_> for Authentication { 11 => AuthenticationSaslContinue::deserialize(buf).map(Self::SaslContinue), 12 => AuthenticationSaslFinal::deserialize(buf).map(Self::SaslFinal), - ty => Err(Error::client(PgClientError::UnknownAuthenticationMethod(ty))), + ty => Err(PgClientError::UnknownAuthenticationMethod(ty).into()), } } } diff --git a/sqlx-postgres/src/protocol/backend/message.rs b/sqlx-postgres/src/protocol/backend/message.rs index 9559046d..ad55238b 100644 --- a/sqlx-postgres/src/protocol/backend/message.rs +++ b/sqlx-postgres/src/protocol/backend/message.rs @@ -68,7 +68,7 @@ impl TryFrom for BackendMessageType { b'c' => Self::CopyDone, _ => { - return Err(Error::client(PgClientError::UnknownMessageType(ty))); + return Err(PgClientError::UnknownMessageType(ty).into()); } }) } diff --git a/sqlx-postgres/src/protocol/backend/ready_for_query.rs b/sqlx-postgres/src/protocol/backend/ready_for_query.rs index b03b49a5..1709a0f7 100644 --- a/sqlx-postgres/src/protocol/backend/ready_for_query.rs +++ b/sqlx-postgres/src/protocol/backend/ready_for_query.rs @@ -30,7 +30,7 @@ impl Deserialize<'_> for ReadyForQuery { b'E' => TransactionStatus::Error, status => { - return Err(Error::client(PgClientError::UnknownTransactionStatus(status))); + return Err(PgClientError::UnknownTransactionStatus(status).into()); } }; diff --git a/sqlx-postgres/src/protocol/backend/row_description.rs b/sqlx-postgres/src/protocol/backend/row_description.rs index 9d680d0e..610924c4 100644 --- a/sqlx-postgres/src/protocol/backend/row_description.rs +++ b/sqlx-postgres/src/protocol/backend/row_description.rs @@ -5,9 +5,11 @@ use bytestring::ByteString; use sqlx_core::io::{BufExt, Deserialize}; use sqlx_core::Result; +use crate::{PgColumn, PgRawValueFormat}; + #[derive(Debug)] pub(crate) struct RowDescription { - pub(crate) fields: Vec, + pub(crate) columns: Vec, } #[derive(Debug)] @@ -24,45 +26,49 @@ pub(crate) struct Field { pub(crate) relation_attribute_no: Option, /// The object ID of the field's data type. - pub(crate) data_type_id: u32, + pub(crate) type_id: u32, /// The data type size (see pg_type.typlen). Note that negative values denote /// variable-width types. - pub(crate) data_type_size: i16, + pub(crate) type_size: i16, /// The type modifier (see pg_attribute.atttypmod). The meaning of the /// modifier is type-specific. pub(crate) type_modifier: i32, /// The format code being used for the field. - pub(crate) format: i16, + pub(crate) format: PgRawValueFormat, } -impl Deserialize<'_> for RowDescription { +impl<'de> Deserialize<'de> for RowDescription { fn deserialize_with(mut buf: Bytes, _: ()) -> Result { let cnt = buf.get_u16() as usize; - let mut fields = Vec::with_capacity(cnt); - for _ in 0..cnt { + let mut columns = Vec::with_capacity(cnt); + + for index in 0..cnt { let name = buf.get_str_nul()?; let relation_id = buf.get_i32(); let relation_attribute_no = buf.get_i16(); - let data_type_id = buf.get_u32(); - let data_type_size = buf.get_i16(); + let type_id = buf.get_u32(); + let type_size = buf.get_i16(); let type_modifier = buf.get_i32(); let format = buf.get_i16(); - fields.push(Field { - name, - relation_id: NonZeroI32::new(relation_id), - relation_attribute_no: NonZeroI16::new(relation_attribute_no), - data_type_id, - data_type_size, - type_modifier, - format, - }) + columns.push(PgColumn::from_field( + index, + Field { + name, + relation_id: NonZeroI32::new(relation_id), + relation_attribute_no: NonZeroI16::new(relation_attribute_no), + type_id, + type_size, + type_modifier, + format: PgRawValueFormat::from_i16(format)?, + }, + )); } - Ok(Self { fields }) + Ok(Self { columns }) } } diff --git a/sqlx-postgres/src/protocol/frontend/query.rs b/sqlx-postgres/src/protocol/frontend/query.rs index 5027ca39..7df854e8 100644 --- a/sqlx-postgres/src/protocol/frontend/query.rs +++ b/sqlx-postgres/src/protocol/frontend/query.rs @@ -4,16 +4,18 @@ use sqlx_core::Result; use crate::io::PgWriteExt; #[derive(Debug)] -pub(crate) struct Query<'a>(pub(crate) &'a str); +pub(crate) struct Query<'a> { + pub(crate) sql: &'a str, +} impl Serialize<'_> for Query<'_> { fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { - buf.reserve(1 + self.0.len() + 1 + 4); + buf.reserve(1 + self.sql.len() + 1 + 4); buf.push(b'Q'); buf.write_len_prefixed(|buf| { - buf.extend_from_slice(self.0.as_bytes()); + buf.extend_from_slice(self.sql.as_bytes()); buf.push(0); Ok(()) diff --git a/sqlx-postgres/src/query_result.rs b/sqlx-postgres/src/query_result.rs index 0cd4b4c7..a201490e 100644 --- a/sqlx-postgres/src/query_result.rs +++ b/sqlx-postgres/src/query_result.rs @@ -5,7 +5,9 @@ use std::str::Utf8Error; use bytes::Bytes; use bytestring::ByteString; use memchr::memrchr; -use sqlx_core::QueryResult; +use sqlx_core::{Error, QueryResult, Result}; + +use crate::PgClientError; // TODO: add unit tests for command tag parsing @@ -21,7 +23,7 @@ pub struct PgQueryResult { } impl PgQueryResult { - pub(crate) fn parse(mut command: Bytes) -> Result { + pub(crate) fn parse(mut command: Bytes) -> Result { // look backwards for the first SPACE let offset = memrchr(b' ', &command); @@ -31,7 +33,9 @@ impl PgQueryResult { 0 }; - Ok(Self { command: command.try_into()?, rows_affected: rows }) + let command: ByteString = command.try_into().map_err(PgClientError::NotUtf8)?; + + Ok(Self { command, rows_affected: rows }) } /// Returns the command tag. diff --git a/sqlx-postgres/src/raw_value.rs b/sqlx-postgres/src/raw_value.rs index b5dcf540..5bc1d202 100644 --- a/sqlx-postgres/src/raw_value.rs +++ b/sqlx-postgres/src/raw_value.rs @@ -1,7 +1,10 @@ -use bytes::Bytes; -use sqlx_core::RawValue; +use std::str::from_utf8; -use crate::{PgTypeInfo, Postgres}; +use bytes::Bytes; +use sqlx_core::decode::{Error as DecodeError, Result as DecodeResult}; +use sqlx_core::{RawValue, Result}; + +use crate::{PgClientError, PgTypeInfo, Postgres}; /// The format of a raw SQL value for Postgres. /// @@ -12,9 +15,21 @@ use crate::{PgTypeInfo, Postgres}; /// For simple queries, postgres only can return values in [`Text`] format. /// #[derive(Debug, PartialEq, Copy, Clone)] +#[repr(i16)] pub enum PgRawValueFormat { - Binary, - Text, + Text = 0, + Binary = 1, +} + +impl PgRawValueFormat { + pub(crate) fn from_i16(value: i16) -> Result { + match value { + 0 => Ok(Self::Text), + 1 => Ok(Self::Binary), + + _ => Err(PgClientError::UnknownValueFormat(value).into()), + } + } } /// The raw representation of a SQL value for Postgres. @@ -27,8 +42,15 @@ pub struct PgRawValue<'r> { type_info: PgTypeInfo, } -// 'r: row impl<'r> PgRawValue<'r> { + pub(crate) fn new( + value: &'r Option, + format: PgRawValueFormat, + type_info: PgTypeInfo, + ) -> Self { + Self { value: value.as_ref(), format, type_info } + } + /// Returns the type information for this value. #[must_use] pub const fn type_info(&self) -> &PgTypeInfo { @@ -40,6 +62,17 @@ impl<'r> PgRawValue<'r> { pub const fn format(&self) -> PgRawValueFormat { self.format } + + /// Returns the underlying byte view of this value. + pub fn as_bytes(&self) -> DecodeResult<&'r [u8]> { + self.value.map(|bytes| &**bytes).ok_or(DecodeError::UnexpectedNull) + } + + /// Returns a `&str` slice from the underlying byte view of this value, + /// if it contains valid UTF-8. + pub fn as_str(&self) -> DecodeResult<&'r str> { + self.as_bytes().and_then(|bytes| from_utf8(bytes).map_err(DecodeError::NotUtf8)) + } } impl<'r> RawValue<'r> for PgRawValue<'r> { diff --git a/sqlx-postgres/src/row.rs b/sqlx-postgres/src/row.rs index 673fc0cc..44ada9f8 100644 --- a/sqlx-postgres/src/row.rs +++ b/sqlx-postgres/src/row.rs @@ -1,47 +1,130 @@ -use sqlx_core::{ColumnIndex, Result, Row}; +use std::sync::Arc; +use bytes::Bytes; +use sqlx_core::{ColumnIndex, Result, Row, TypeDecode}; + +use crate::protocol::backend::DataRow; use crate::{PgColumn, PgRawValue, Postgres}; -/// A single row from a result set generated from MySQL. +/// A single result row from a query in PostgreSQL. #[allow(clippy::module_name_repetitions)] -pub struct PgRow {} +pub struct PgRow { + values: Vec>, + columns: Arc<[PgColumn]>, +} + +impl PgRow { + pub(crate) fn new(data: DataRow, columns: &Option>) -> Self { + Self { + values: data.values, + columns: columns.as_ref().map(Arc::clone).unwrap_or_else(|| Arc::new([])), + } + } + + /// Returns `true` if the row contains only `NULL` values. + pub fn is_null(&self) -> bool { + self.values.iter().all(Option::is_some) + } + + /// Returns the number of columns in the row. + #[must_use] + pub fn len(&self) -> usize { + self.values.len() + } + + /// Returns `true` if there are no columns in the row. + #[must_use] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns a reference to the columns in the row. + #[must_use] + pub fn columns(&self) -> &[PgColumn] { + &self.columns + } + + /// Returns the column at the index, if available. + pub fn column>(&self, index: I) -> &PgColumn { + Row::column(self, index) + } + + /// Returns the column at the index, if available. + pub fn try_column>(&self, index: I) -> Result<&PgColumn> { + Ok(&self.columns[index.get(self)?]) + } + + /// Returns the decoded value at the index. + pub fn get<'r, T, I>(&'r self, index: I) -> T + where + I: ColumnIndex, + T: TypeDecode<'r, Postgres>, + { + Row::get(self, index) + } + + /// Returns the decoded value at the index. + pub fn try_get<'r, T, I>(&'r self, index: I) -> Result + where + I: ColumnIndex, + T: TypeDecode<'r, Postgres>, + { + Row::try_get(self, index) + } + + /// Returns the raw representation of the value at the index. + #[allow(clippy::needless_lifetimes)] + pub fn get_raw<'r, I>(&'r self, index: I) -> PgRawValue<'r> + where + I: ColumnIndex, + { + Row::get_raw(self, index) + } + + /// Returns the raw representation of the value at the index. + #[allow(clippy::needless_lifetimes)] + pub fn try_get_raw<'r, I>(&'r self, index: I) -> Result> + where + I: ColumnIndex, + { + let index = index.get(self)?; + + let value = &self.values[index]; + let column = &self.columns[index]; + + Ok(PgRawValue::new(value, column.format, column.type_info)) + } +} impl Row for PgRow { type Database = Postgres; fn is_null(&self) -> bool { - // self.is_null() - todo!() + self.is_null() } fn len(&self) -> usize { - // self.len() - todo!() + self.len() } fn columns(&self) -> &[PgColumn] { - // self.columns() - todo!() + self.columns() } fn try_column>(&self, index: I) -> Result<&PgColumn> { - // self.try_column(index) - todo!() + self.try_column(index) } fn column_name(&self, index: usize) -> Option<&str> { - // self.columns.get(index).map(PgColumn::name) - todo!() + self.columns.get(index).map(PgColumn::name) } fn column_index(&self, name: &str) -> Option { - // self.columns.iter().position(|col| col.name() == name) - todo!() + self.columns.iter().position(|col| col.name() == name) } #[allow(clippy::needless_lifetimes)] fn try_get_raw<'r, I: ColumnIndex>(&'r self, index: I) -> Result> { - // self.try_get_raw(index) - todo!() + self.try_get_raw(index) } } diff --git a/sqlx-postgres/src/stream.rs b/sqlx-postgres/src/stream.rs index ba9ee0b8..72f9a8c1 100644 --- a/sqlx-postgres/src/stream.rs +++ b/sqlx-postgres/src/stream.rs @@ -157,11 +157,11 @@ impl DerefMut for PgStream { macro_rules! read_message { (@blocking $stream:expr) => { - $stream.read_message_blocking()? + $stream.read_message_blocking() }; ($stream:expr) => { - $stream.read_message_async().await? + $stream.read_message_async().await }; } diff --git a/sqlx-postgres/src/types.rs b/sqlx-postgres/src/types.rs index 8b137891..2a92883e 100644 --- a/sqlx-postgres/src/types.rs +++ b/sqlx-postgres/src/types.rs @@ -1 +1,3 @@ +mod bool; +// https://www.postgresql.org/docs/current/datatype.html diff --git a/sqlx-postgres/src/types/bool.rs b/sqlx-postgres/src/types/bool.rs new file mode 100644 index 00000000..b8c71496 --- /dev/null +++ b/sqlx-postgres/src/types/bool.rs @@ -0,0 +1,41 @@ +use sqlx_core::{decode, encode, Decode, Encode, Type}; + +use crate::{PgOutput, PgRawValue, PgRawValueFormat, PgTypeId, PgTypeInfo, Postgres}; + +// + +impl Type for bool { + fn type_id() -> PgTypeId + where + Self: Sized, + { + PgTypeId::BOOLEAN + } +} + +impl Encode for bool { + fn encode(&self, ty: &PgTypeInfo, out: &mut PgOutput<'_>) -> encode::Result<()> { + out.buffer().push(*self as u8); + + Ok(()) + } +} + +impl<'r> Decode<'r, Postgres> for bool { + fn decode(value: PgRawValue<'r>) -> decode::Result { + Ok(match value.format() { + PgRawValueFormat::Binary => value.as_bytes()?[0] != 0, + PgRawValueFormat::Text => match value.as_str()? { + "t" => true, + "f" => false, + + s => { + return Err(decode::Error::msg(format!( + "unexpected value {:?} for `boolean`", + s + ))); + } + }, + }) + } +}