diff --git a/sqlx-postgres/src/column.rs b/sqlx-postgres/src/column.rs index fc192c85..5cf9e1f8 100644 --- a/sqlx-postgres/src/column.rs +++ b/sqlx-postgres/src/column.rs @@ -1,7 +1,7 @@ use std::num::{NonZeroI16, NonZeroI32}; use bytestring::ByteString; -use sqlx_core::{Column, Database}; +use sqlx_core::Column; use crate::protocol::backend::Field; use crate::{PgRawValueFormat, PgTypeId, PgTypeInfo, Postgres}; diff --git a/sqlx-postgres/src/connection.rs b/sqlx-postgres/src/connection.rs index 7c75f959..01649adb 100644 --- a/sqlx-postgres/src/connection.rs +++ b/sqlx-postgres/src/connection.rs @@ -1,7 +1,7 @@ use std::fmt::{self, Debug, Formatter}; #[cfg(feature = "async")] -use futures_util::future::{BoxFuture, FutureExt, TryFutureExt}; +use futures_util::future::{BoxFuture, FutureExt}; use sqlx_core::net::Stream as NetStream; use sqlx_core::{Close, Connect, Connection, Runtime}; @@ -20,6 +20,9 @@ mod executor; pub struct PgConnection { stream: PgStream, + // next statement identifier + next_statement_id: u32, + // number of commands that have been executed // and have yet to see their completion acknowledged // in other words, the number of messages @@ -57,6 +60,7 @@ impl PgConnection { secret_key: 0, transaction_status: TransactionStatus::Idle, pending_ready_for_query_count: 0, + next_statement_id: 1, } } } diff --git a/sqlx-postgres/src/connection/connect.rs b/sqlx-postgres/src/connection/connect.rs index 172c2fb7..eec53705 100644 --- a/sqlx-postgres/src/connection/connect.rs +++ b/sqlx-postgres/src/connection/connect.rs @@ -13,7 +13,7 @@ //! use sqlx_core::net::Stream as NetStream; -use sqlx_core::{Error, Result, Runtime}; +use sqlx_core::{Result, Runtime}; use crate::protocol::backend::{Authentication, BackendMessage, BackendMessageType, KeyData}; use crate::protocol::frontend::{Password, PasswordMd5, Startup}; diff --git a/sqlx-postgres/src/connection/executor.rs b/sqlx-postgres/src/connection/executor.rs index c61ce6bb..229f58f8 100644 --- a/sqlx-postgres/src/connection/executor.rs +++ b/sqlx-postgres/src/connection/executor.rs @@ -1,5 +1,5 @@ #[cfg(feature = "async")] -use futures_util::{future::BoxFuture, FutureExt}; +use futures_util::future::BoxFuture; use sqlx_core::{Execute, Executor, Result, Runtime}; use crate::protocol::backend::ReadyForQuery; diff --git a/sqlx-postgres/src/connection/executor/execute.rs b/sqlx-postgres/src/connection/executor/execute.rs index 75b36302..b871f44b 100644 --- a/sqlx-postgres/src/connection/executor/execute.rs +++ b/sqlx-postgres/src/connection/executor/execute.rs @@ -1,4 +1,4 @@ -use sqlx_core::{Error, Execute, Result, Runtime}; +use sqlx_core::{Execute, Result, Runtime}; use crate::protocol::backend::{BackendMessage, BackendMessageType}; use crate::{PgClientError, PgConnection, PgQueryResult, Postgres}; @@ -10,6 +10,8 @@ impl PgConnection { result: &mut PgQueryResult, ) -> Result { match message.ty { + BackendMessageType::BindComplete => {} + // ignore rows received or metadata about them // TODO: should we log a warning? its wasteful to use `execute` on a query // that does return rows diff --git a/sqlx-postgres/src/connection/executor/fetch_all.rs b/sqlx-postgres/src/connection/executor/fetch_all.rs index 09ca9a64..0a345a18 100644 --- a/sqlx-postgres/src/connection/executor/fetch_all.rs +++ b/sqlx-postgres/src/connection/executor/fetch_all.rs @@ -1,10 +1,9 @@ use std::sync::Arc; -use sqlx_core::io::Deserialize; -use sqlx_core::{Error, Execute, Result, Runtime}; +use sqlx_core::{Execute, Result, Runtime}; -use crate::protocol::backend::{BackendMessage, BackendMessageType, ReadyForQuery, RowDescription}; -use crate::{PgClientError, PgColumn, PgConnection, PgQueryResult, PgRow, Postgres}; +use crate::protocol::backend::{BackendMessage, BackendMessageType, RowDescription}; +use crate::{PgClientError, PgColumn, PgConnection, PgRow, Postgres}; impl PgConnection { fn handle_message_in_fetch_all( @@ -14,6 +13,8 @@ impl PgConnection { columns: &mut Option>, ) -> Result { match message.ty { + BackendMessageType::BindComplete => {} + BackendMessageType::DataRow => { rows.push(PgRow::new(message.deserialize()?, &columns)); } diff --git a/sqlx-postgres/src/connection/executor/fetch_optional.rs b/sqlx-postgres/src/connection/executor/fetch_optional.rs index eadbc7bf..cbb498fe 100644 --- a/sqlx-postgres/src/connection/executor/fetch_optional.rs +++ b/sqlx-postgres/src/connection/executor/fetch_optional.rs @@ -1,10 +1,9 @@ use std::sync::Arc; -use sqlx_core::io::Deserialize; -use sqlx_core::{Error, Execute, Result, Runtime}; +use sqlx_core::{Execute, Result, Runtime}; -use crate::protocol::backend::{BackendMessage, BackendMessageType, ReadyForQuery, RowDescription}; -use crate::{PgClientError, PgColumn, PgConnection, PgQueryResult, PgRow, Postgres}; +use crate::protocol::backend::{BackendMessage, BackendMessageType, RowDescription}; +use crate::{PgClientError, PgColumn, PgConnection, PgRow, Postgres}; impl PgConnection { fn handle_message_in_fetch_optional( @@ -14,6 +13,8 @@ impl PgConnection { columns: &mut Option>, ) -> Result { match message.ty { + BackendMessageType::BindComplete => {} + BackendMessageType::DataRow => { debug_assert!(first_row.is_none()); diff --git a/sqlx-postgres/src/connection/executor/raw_prepare.rs b/sqlx-postgres/src/connection/executor/raw_prepare.rs index e69de29b..5edee483 100644 --- a/sqlx-postgres/src/connection/executor/raw_prepare.rs +++ b/sqlx-postgres/src/connection/executor/raw_prepare.rs @@ -0,0 +1,129 @@ +use sqlx_core::{Result, Runtime}; + +use crate::protocol::backend::{ + BackendMessage, BackendMessageType, ParameterDescription, RowDescription, +}; +use crate::protocol::frontend::{Describe, Parse, StatementRef, Sync, Target}; +use crate::raw_statement::RawStatement; +use crate::{PgArguments, PgClientError, PgConnection}; + +impl PgConnection { + fn start_raw_prepare( + &mut self, + sql: &str, + arguments: &PgArguments<'_>, + ) -> Result { + let statement_id = self.next_statement_id; + self.next_statement_id = self.next_statement_id.wrapping_add(1); + + let statement = RawStatement::new(statement_id); + + self.stream.write_message(&Parse { + statement: StatementRef::Named(statement.id), + sql, + arguments, + })?; + + self.stream.write_message(&Describe { + target: Target::Statement(StatementRef::Named(statement.id)), + })?; + + self.stream.write_message(&Sync)?; + + self.pending_ready_for_query_count += 1; + + Ok(statement) + } + + fn handle_message_in_raw_prepare( + &mut self, + message: BackendMessage, + statement: &mut RawStatement, + ) -> Result { + match message.ty { + BackendMessageType::ParseComplete => { + // next message should be + } + + BackendMessageType::ReadyForQuery => { + self.handle_ready_for_query(message.deserialize()?); + + return Ok(true); + } + + BackendMessageType::ParameterDescription => { + let pd: ParameterDescription = message.deserialize()?; + statement.parameters = pd.parameters; + } + + BackendMessageType::RowDescription => { + let rd: RowDescription = message.deserialize()?; + statement.columns = rd.columns; + } + + ty => { + return Err(PgClientError::UnexpectedMessageType { + ty: ty as u8, + context: "preparing a query", + } + .into()); + } + } + + Ok(false) + } +} + +macro_rules! impl_raw_prepare { + ($(@$blocking:ident)? $self:ident, $sql:ident, $arguments:ident) => {{ + let mut statement = $self.start_raw_prepare($sql, $arguments)?; + + loop { + let message = read_message!($(@$blocking)? $self.stream)?; + + if $self.handle_message_in_raw_prepare(message, &mut statement)? { + break; + } + } + + Ok(statement) + }}; +} + +impl super::PgConnection { + #[cfg(feature = "async")] + pub(crate) async fn raw_prepare_async( + &mut self, + sql: &str, + arguments: &PgArguments<'_>, + ) -> Result + where + Rt: sqlx_core::Async, + { + flush!(self); + impl_raw_prepare!(self, sql, arguments) + } + + #[cfg(feature = "blocking")] + pub(crate) fn raw_prepare_blocking( + &mut self, + sql: &str, + arguments: &PgArguments<'_>, + ) -> Result + where + Rt: sqlx_core::blocking::Runtime, + { + flush!(@blocking self); + impl_raw_prepare!(@blocking self, sql, arguments) + } +} + +macro_rules! raw_prepare { + (@blocking $self:ident, $sql:expr, $arguments:expr) => { + $self.raw_prepare_blocking($sql, $arguments)? + }; + + ($self:ident, $sql:expr, $arguments:expr) => { + $self.raw_prepare_async($sql, $arguments).await? + }; +} diff --git a/sqlx-postgres/src/connection/executor/raw_query.rs b/sqlx-postgres/src/connection/executor/raw_query.rs index 57160e55..a0e680bb 100644 --- a/sqlx-postgres/src/connection/executor/raw_query.rs +++ b/sqlx-postgres/src/connection/executor/raw_query.rs @@ -1,19 +1,39 @@ use sqlx_core::{Execute, Result, Runtime}; -use crate::protocol::frontend::Query; -use crate::{PgConnection, PgRawValueFormat, Postgres}; +use crate::protocol::frontend::{self, Bind, PortalRef, Query, StatementRef, Sync}; +use crate::{PgConnection, Postgres}; macro_rules! impl_raw_query { ($(@$blocking:ident)? $self:ident, $query:ident) => {{ - let format = if let Some(arguments) = $query.arguments() { - todo!("prepared query for postgres") + if let Some(arguments) = $query.arguments() { + // prepare the statement for execution + let statement = raw_prepare!($(@$blocking)? $self, $query.sql(), arguments); + + // bind values to the prepared statement + $self.stream.write_message(&Bind { + portal: PortalRef::Unnamed, + statement: StatementRef::Named(statement.id), + arguments, + parameters: &statement.parameters, + })?; + + // describe the bound prepared statement (portal) + $self.stream.write_message(&frontend::Describe { + target: frontend::Target::Portal(PortalRef::Unnamed), + })?; + + // execute the bound prepared statement (portal) + $self.stream.write_message(&frontend::Execute { + portal: PortalRef::Unnamed, + max_rows: 0, + })?; + + // is what closes the extended query invocation and + // issues a + $self.stream.write_message(&Sync)?; } else { // directly execute the query as an unprepared, simple query $self.stream.write_message(&Query { sql: $query.sql() })?; - - // unprepared queries use the TEXT format - // this is a significant waste of bandwidth for large result sets - PgRawValueFormat::Text }; // as we have written a SQL command of some kind to the stream @@ -22,13 +42,13 @@ macro_rules! impl_raw_query { // half-way through, we need to flush the stream until the ReadyForQuery point $self.pending_ready_for_query_count += 1; - Ok(format) + Ok(()) }}; } impl PgConnection { #[cfg(feature = "async")] - pub(super) async fn raw_query_async<'q, 'a, E>(&mut self, query: E) -> Result + pub(super) async fn raw_query_async<'q, 'a, E>(&mut self, query: E) -> Result<()> where Rt: sqlx_core::Async, E: Execute<'q, 'a, Postgres>, @@ -38,7 +58,7 @@ impl PgConnection { } #[cfg(feature = "blocking")] - pub(super) fn raw_query_blocking<'q, 'a, E>(&mut self, query: E) -> Result + pub(super) fn raw_query_blocking<'q, 'a, E>(&mut self, query: E) -> Result<()> where Rt: sqlx_core::blocking::Runtime, E: Execute<'q, 'a, Postgres>, diff --git a/sqlx-postgres/src/connection/flush.rs b/sqlx-postgres/src/connection/flush.rs index 4d1e9f85..69568641 100644 --- a/sqlx-postgres/src/connection/flush.rs +++ b/sqlx-postgres/src/connection/flush.rs @@ -1,6 +1,7 @@ +use sqlx_core::{Error, Result, Runtime}; + use crate::protocol::backend::{BackendMessage, BackendMessageType}; use crate::PgConnection; -use sqlx_core::{Error, Result, Runtime}; impl PgConnection { fn handle_message_in_flush(&mut self, message: BackendMessage) -> Result { diff --git a/sqlx-postgres/src/error/client.rs b/sqlx-postgres/src/error/client.rs index cdf02cba..845057c8 100644 --- a/sqlx-postgres/src/error/client.rs +++ b/sqlx-postgres/src/error/client.rs @@ -4,8 +4,6 @@ use std::str::Utf8Error; use sqlx_core::{ClientError, Error}; -use crate::protocol::backend::BackendMessageType; - #[derive(Debug)] #[non_exhaustive] pub enum PgClientError { diff --git a/sqlx-postgres/src/lib.rs b/sqlx-postgres/src/lib.rs index 1846a0e4..a72d8eb1 100644 --- a/sqlx-postgres/src/lib.rs +++ b/sqlx-postgres/src/lib.rs @@ -34,7 +34,7 @@ mod options; mod output; mod protocol; mod query_result; -// mod raw_statement; +mod raw_statement; mod raw_value; mod row; // mod transaction; @@ -58,5 +58,4 @@ pub use row::PgRow; pub use type_id::PgTypeId; pub use type_info::PgTypeInfo; -// 'a: argument values -pub type PgArguments<'a> = Arguments<'a, Postgres>; +pub type PgArguments<'v> = Arguments<'v, Postgres>; diff --git a/sqlx-postgres/src/protocol/backend/auth.rs b/sqlx-postgres/src/protocol/backend/auth.rs index dcc17ece..d9899e1a 100644 --- a/sqlx-postgres/src/protocol/backend/auth.rs +++ b/sqlx-postgres/src/protocol/backend/auth.rs @@ -1,6 +1,6 @@ use bytes::{Buf, Bytes}; use sqlx_core::io::Deserialize; -use sqlx_core::{Error, Result}; +use sqlx_core::Result; use crate::protocol::backend::{ AuthenticationSasl, AuthenticationSaslContinue, AuthenticationSaslFinal, diff --git a/sqlx-postgres/src/protocol/backend/parameter_description.rs b/sqlx-postgres/src/protocol/backend/parameter_description.rs index f3c22674..ce489097 100644 --- a/sqlx-postgres/src/protocol/backend/parameter_description.rs +++ b/sqlx-postgres/src/protocol/backend/parameter_description.rs @@ -2,9 +2,11 @@ use bytes::{Buf, Bytes}; use sqlx_core::io::Deserialize; use sqlx_core::Result; +use crate::{PgTypeId, PgTypeInfo}; + #[derive(Debug)] pub(crate) struct ParameterDescription { - pub(crate) parameters: Vec, + pub(crate) parameters: Vec, } impl Deserialize<'_> for ParameterDescription { @@ -13,7 +15,7 @@ impl Deserialize<'_> for ParameterDescription { let mut parameters = Vec::with_capacity(cnt as usize); for _ in 0..cnt { - parameters.push(buf.get_u32()); + parameters.push(PgTypeInfo(PgTypeId::Oid(buf.get_u32()))); } Ok(Self { parameters }) diff --git a/sqlx-postgres/src/protocol/backend/parameter_status.rs b/sqlx-postgres/src/protocol/backend/parameter_status.rs index a5b5d63a..83af7c0f 100644 --- a/sqlx-postgres/src/protocol/backend/parameter_status.rs +++ b/sqlx-postgres/src/protocol/backend/parameter_status.rs @@ -1,4 +1,4 @@ -use bytes::{Buf, Bytes}; +use bytes::Bytes; use bytestring::ByteString; use sqlx_core::io::{BufExt, Deserialize}; use sqlx_core::Result; diff --git a/sqlx-postgres/src/protocol/backend/ready_for_query.rs b/sqlx-postgres/src/protocol/backend/ready_for_query.rs index 1709a0f7..a207a08e 100644 --- a/sqlx-postgres/src/protocol/backend/ready_for_query.rs +++ b/sqlx-postgres/src/protocol/backend/ready_for_query.rs @@ -1,6 +1,6 @@ -use bytes::{Buf, Bytes}; +use bytes::Bytes; use sqlx_core::io::Deserialize; -use sqlx_core::{Error, Result}; +use sqlx_core::Result; use crate::PgClientError; diff --git a/sqlx-postgres/src/protocol/frontend/bind.rs b/sqlx-postgres/src/protocol/frontend/bind.rs index b42e59f6..021b0284 100644 --- a/sqlx-postgres/src/protocol/frontend/bind.rs +++ b/sqlx-postgres/src/protocol/frontend/bind.rs @@ -1,16 +1,28 @@ -use sqlx_core::io::{Serialize, WriteExt}; +use std::fmt::{self, Debug, Formatter}; + +use sqlx_core::io::Serialize; use sqlx_core::Result; use crate::io::PgWriteExt; use crate::protocol::frontend::{PortalRef, StatementRef}; -use crate::PgArguments; +use crate::{PgArguments, PgOutput, PgRawValueFormat, PgTypeInfo}; pub(crate) struct Bind<'a> { pub(crate) portal: PortalRef, pub(crate) statement: StatementRef, + pub(crate) parameters: &'a [PgTypeInfo], pub(crate) arguments: &'a PgArguments<'a>, } +impl Debug for Bind<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("Bind") + .field("statement", &self.statement) + .field("portal", &self.portal) + .finish() + } +} + impl Serialize<'_> for Bind<'_> { fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { buf.push(b'B'); @@ -20,13 +32,43 @@ impl Serialize<'_> for Bind<'_> { // the parameter format codes, each must presently be zero (text) or one (binary) // can use one to indicate that all parameters use that format - write_i16_arr(buf, &[1]); + write_i16_arr(buf, &[PgRawValueFormat::Binary as i16]); - todo!("arguments"); + // note: this should have been checked in parse + debug_assert!(!(self.arguments.len() >= (u16::MAX as usize))); + + // note: named arguments should have been converted to positional before this point + debug_assert_eq!(self.arguments.num_named(), 0); + + buf.extend(&(self.parameters.len() as i16).to_be_bytes()); + + let mut out = PgOutput::new(buf); + let mut args = self.arguments.positional(); + + for param in self.parameters { + // reserve space to write the prefixed length of the value + let offset = out.buffer().len(); + out.buffer().extend_from_slice(&[0; 4]); + + let len = if let Some(argument) = args.next() { + argument.encode(param, &mut out)?; + + // prefixed length does not include the length in the length + // unlike the regular "prefixed length" bytes protocol type + (out.buffer().len() - offset - 4) as i32 + } else { + // if we run out of values, start sending NULL for + // NULL is encoded as a -1 for the length + -1_i32 + }; + + // write the len to the beginning of the value + out.buffer()[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes()); + } // the result format codes, each must presently be zero (text) or one (binary) // can use one to indicate that all results use that format - write_i16_arr(buf, &[1]); + write_i16_arr(buf, &[PgRawValueFormat::Binary as i16]); Ok(()) }) diff --git a/sqlx-postgres/src/protocol/frontend/close.rs b/sqlx-postgres/src/protocol/frontend/close.rs index af7d3162..f3612599 100644 --- a/sqlx-postgres/src/protocol/frontend/close.rs +++ b/sqlx-postgres/src/protocol/frontend/close.rs @@ -1,7 +1,6 @@ -use sqlx_core::io::{Serialize, WriteExt}; +use sqlx_core::io::Serialize; use sqlx_core::Result; -use crate::io::PgWriteExt; use crate::protocol::frontend::Target; #[derive(Debug)] @@ -12,6 +11,6 @@ pub(crate) struct Close { impl Serialize<'_> for Close { fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { buf.push(b'C'); - buf.write_len_prefixed(|buf| self.target.serialize(buf)) + self.target.serialize(buf) } } diff --git a/sqlx-postgres/src/protocol/frontend/describe.rs b/sqlx-postgres/src/protocol/frontend/describe.rs index 05671952..acfe2924 100644 --- a/sqlx-postgres/src/protocol/frontend/describe.rs +++ b/sqlx-postgres/src/protocol/frontend/describe.rs @@ -1,17 +1,16 @@ -use sqlx_core::io::{Serialize, WriteExt}; +use sqlx_core::io::Serialize; use sqlx_core::Result; -use crate::io::PgWriteExt; use crate::protocol::frontend::Target; #[derive(Debug)] pub(crate) struct Describe { - target: Target, + pub(crate) target: Target, } impl Serialize<'_> for Describe { fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { buf.push(b'D'); - buf.write_len_prefixed(|buf| self.target.serialize(buf)) + self.target.serialize(buf) } } diff --git a/sqlx-postgres/src/protocol/frontend/flush.rs b/sqlx-postgres/src/protocol/frontend/flush.rs index e90972af..71f09a39 100644 --- a/sqlx-postgres/src/protocol/frontend/flush.rs +++ b/sqlx-postgres/src/protocol/frontend/flush.rs @@ -7,6 +7,7 @@ pub(crate) struct Flush; impl Serialize<'_> for Flush { fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { buf.push(b'H'); + buf.extend_from_slice(&4_i32.to_be_bytes()); Ok(()) } diff --git a/sqlx-postgres/src/protocol/frontend/parse.rs b/sqlx-postgres/src/protocol/frontend/parse.rs index 58b23aae..1ceaad25 100644 --- a/sqlx-postgres/src/protocol/frontend/parse.rs +++ b/sqlx-postgres/src/protocol/frontend/parse.rs @@ -1,18 +1,16 @@ +use std::fmt::{self, Debug, Formatter}; + use sqlx_core::io::{Serialize, WriteExt}; use sqlx_core::Result; use crate::io::PgWriteExt; -use crate::protocol::frontend::{PortalRef, StatementRef}; +use crate::protocol::frontend::StatementRef; +use crate::{PgArguments, PgTypeId}; -#[derive(Debug)] pub(crate) struct Parse<'a> { pub(crate) statement: StatementRef, pub(crate) sql: &'a str, - - /// The parameter data types specified (could be zero). Note that this is not an - /// indication of the number of parameters that might appear in the query string, - /// only the number that the frontend wants to pre-specify types for. - pub(crate) parameters: &'a [u32], + pub(crate) arguments: &'a PgArguments<'a>, } impl Serialize<'_> for Parse<'_> { @@ -24,11 +22,19 @@ impl Serialize<'_> for Parse<'_> { buf.write_str_nul(self.sql); // TODO: return a proper error - assert!(!(self.parameters.len() >= (u16::MAX as usize))); + assert!(!(self.arguments.len() >= (u16::MAX as usize))); - buf.extend(&(self.parameters.len() as u16).to_be_bytes()); + // note: named arguments should have been converted to positional before this point + debug_assert_eq!(self.arguments.num_named(), 0); + + buf.extend(&(self.arguments.len() as u16).to_be_bytes()); + + for arg in self.arguments.positional() { + let oid = match arg.type_id() { + PgTypeId::Oid(oid) => oid, + PgTypeId::Name(_) => todo!(), + }; - for &oid in self.parameters { buf.extend(&oid.to_be_bytes()); } @@ -36,3 +42,9 @@ impl Serialize<'_> for Parse<'_> { }) } } + +impl Debug for Parse<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("Parse").field("statement", &self.statement).field("sql", &self.sql).finish() + } +} diff --git a/sqlx-postgres/src/protocol/frontend/statement.rs b/sqlx-postgres/src/protocol/frontend/statement.rs index 9c134027..3e5e2b2a 100644 --- a/sqlx-postgres/src/protocol/frontend/statement.rs +++ b/sqlx-postgres/src/protocol/frontend/statement.rs @@ -1,7 +1,7 @@ use sqlx_core::io::Serialize; use sqlx_core::Result; -#[derive(Debug)] +#[derive(Debug, Copy, Clone)] pub(crate) enum StatementRef { Unnamed, Named(u32), diff --git a/sqlx-postgres/src/protocol/frontend/sync.rs b/sqlx-postgres/src/protocol/frontend/sync.rs index 57724c75..db4c6aff 100644 --- a/sqlx-postgres/src/protocol/frontend/sync.rs +++ b/sqlx-postgres/src/protocol/frontend/sync.rs @@ -7,6 +7,7 @@ pub(crate) struct Sync; impl Serialize<'_> for Sync { fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { buf.push(b'S'); + buf.extend_from_slice(&4_i32.to_be_bytes()); Ok(()) } diff --git a/sqlx-postgres/src/protocol/frontend/target.rs b/sqlx-postgres/src/protocol/frontend/target.rs index d2a84524..75ed23cf 100644 --- a/sqlx-postgres/src/protocol/frontend/target.rs +++ b/sqlx-postgres/src/protocol/frontend/target.rs @@ -18,12 +18,12 @@ impl Serialize<'_> for Target { match self { Self::Portal(portal) => { buf.push(b'P'); - portal.serialize(buf); + portal.serialize(buf)?; } Self::Statement(statement) => { buf.push(b'S'); - statement.serialize(buf); + statement.serialize(buf)?; } } diff --git a/sqlx-postgres/src/protocol/frontend/terminate.rs b/sqlx-postgres/src/protocol/frontend/terminate.rs index 8c7fccdb..ae409af1 100644 --- a/sqlx-postgres/src/protocol/frontend/terminate.rs +++ b/sqlx-postgres/src/protocol/frontend/terminate.rs @@ -9,6 +9,7 @@ pub(crate) struct Terminate; impl Serialize<'_> for Terminate { fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { buf.push(b'X'); + buf.extend_from_slice(&4_i32.to_be_bytes()); Ok(()) } diff --git a/sqlx-postgres/src/query_result.rs b/sqlx-postgres/src/query_result.rs index a201490e..7f7bae60 100644 --- a/sqlx-postgres/src/query_result.rs +++ b/sqlx-postgres/src/query_result.rs @@ -1,11 +1,10 @@ use std::convert::TryInto; use std::fmt::{self, Debug, Formatter}; -use std::str::Utf8Error; use bytes::Bytes; use bytestring::ByteString; use memchr::memrchr; -use sqlx_core::{Error, QueryResult, Result}; +use sqlx_core::{QueryResult, Result}; use crate::PgClientError; diff --git a/sqlx-postgres/src/raw_statement.rs b/sqlx-postgres/src/raw_statement.rs new file mode 100644 index 00000000..efe468b0 --- /dev/null +++ b/sqlx-postgres/src/raw_statement.rs @@ -0,0 +1,14 @@ +use crate::{PgColumn, PgTypeInfo}; + +#[derive(Debug, Clone)] +pub(crate) struct RawStatement { + pub(crate) id: u32, + pub(crate) columns: Vec, + pub(crate) parameters: Vec, +} + +impl RawStatement { + pub(crate) fn new(id: u32) -> Self { + Self { id, columns: Vec::new(), parameters: Vec::new() } + } +} diff --git a/sqlx-postgres/src/stream.rs b/sqlx-postgres/src/stream.rs index 72f9a8c1..796bd3b8 100644 --- a/sqlx-postgres/src/stream.rs +++ b/sqlx-postgres/src/stream.rs @@ -105,8 +105,10 @@ macro_rules! impl_read_message { // bytes 1..4 will be the length of the message let size = ($self.stream.get(1, 4).get_u32() - 4) as usize; - // read bytes _after_ the header - impl_read_message!($(@$blocking)? @stream $self, 4, size); + if size > 0 { + // read bytes _after_ the header + impl_read_message!($(@$blocking)? @stream $self, 4, size); + } if let Some(message) = $self.read_message(size)? { break message; diff --git a/sqlx-postgres/src/type_id.rs b/sqlx-postgres/src/type_id.rs index f10ef745..aa6b6d2a 100644 --- a/sqlx-postgres/src/type_id.rs +++ b/sqlx-postgres/src/type_id.rs @@ -117,4 +117,8 @@ impl PgTypeId { _ => "UNKNOWN", } } + + pub(crate) const fn is_integer(&self) -> bool { + matches!(*self, Self::SMALLINT | Self::INTEGER | Self::BIGINT) + } } diff --git a/sqlx-postgres/src/types.rs b/sqlx-postgres/src/types.rs index 2a92883e..8ab53059 100644 --- a/sqlx-postgres/src/types.rs +++ b/sqlx-postgres/src/types.rs @@ -1,3 +1,4 @@ mod bool; +mod int; // https://www.postgresql.org/docs/current/datatype.html diff --git a/sqlx-postgres/src/types/bool.rs b/sqlx-postgres/src/types/bool.rs index b8c71496..2df99e85 100644 --- a/sqlx-postgres/src/types/bool.rs +++ b/sqlx-postgres/src/types/bool.rs @@ -2,7 +2,7 @@ use sqlx_core::{decode, encode, Decode, Encode, Type}; use crate::{PgOutput, PgRawValue, PgRawValueFormat, PgTypeId, PgTypeInfo, Postgres}; -// +// https://www.postgresql.org/docs/current/datatype-boolean.html impl Type for bool { fn type_id() -> PgTypeId @@ -14,7 +14,7 @@ impl Type for bool { } impl Encode for bool { - fn encode(&self, ty: &PgTypeInfo, out: &mut PgOutput<'_>) -> encode::Result<()> { + fn encode(&self, _ty: &PgTypeInfo, out: &mut PgOutput<'_>) -> encode::Result<()> { out.buffer().push(*self as u8); Ok(()) diff --git a/sqlx-postgres/src/types/int.rs b/sqlx-postgres/src/types/int.rs new file mode 100644 index 00000000..6608f486 --- /dev/null +++ b/sqlx-postgres/src/types/int.rs @@ -0,0 +1,111 @@ +use std::cmp; +use std::convert::{TryFrom, TryInto}; +use std::error::Error as StdError; +use std::str::FromStr; + +use bytes::Buf; +use sqlx_core::{decode, encode, Decode, Encode, Type}; + +use crate::{PgOutput, PgRawValue, PgRawValueFormat, PgTypeId, PgTypeInfo, Postgres}; + +// https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-INT + +// todo: allow encode/decode across different integer types +// todo: condense with a macro + +// check that the incoming value is not too large or too small +// to fit into the target SQL type +fn ensure_not_too_large_or_too_small(value: i128, ty: &PgTypeInfo) -> encode::Result<()> { + let max: i128 = match ty.id() { + PgTypeId::SMALLINT => i16::MAX as _, + PgTypeId::INTEGER => i32::MAX as _, + PgTypeId::BIGINT => i64::MAX as _, + + // not an integer type + _ => unreachable!(), + }; + + let min: i128 = match ty.id() { + PgTypeId::SMALLINT => i16::MIN as _, + PgTypeId::INTEGER => i32::MIN as _, + PgTypeId::BIGINT => i64::MIN as _, + + // not an integer type + _ => unreachable!(), + }; + + if value > max { + return Err(encode::Error::msg(format!( + "number `{}` too large to fit in SQL type `{}`", + value, + ty.name() + ))); + } + + if value < min { + return Err(encode::Error::msg(format!( + "number `{}` too small to fit in SQL type `{}`", + value, + ty.name() + ))); + } + + Ok(()) +} + +fn decode_int(value: &PgRawValue<'_>) -> decode::Result +where + T: TryFrom + TryFrom + FromStr, + >::Error: 'static + StdError + Send + Sync, + >::Error: 'static + StdError + Send + Sync, + ::Err: 'static + StdError + Send + Sync, +{ + if value.format() == PgRawValueFormat::Text { + return Ok(value.as_str()?.parse()?); + } + + let mut bytes = value.as_bytes()?; + let size = cmp::min(bytes.len(), 8); + + Ok(bytes.get_int(size).try_into()?) +} + +macro_rules! impl_type_int { + ($ty:ty $(: $real:ty)? => $sql:ident) => { + impl Type for $ty { + fn type_id() -> PgTypeId { + PgTypeId::$sql + } + + fn compatible(ty: &PgTypeInfo) -> bool { + ty.id().is_integer() + } + } + + impl Encode for $ty { + fn encode(&self, ty: &PgTypeInfo, out: &mut PgOutput<'_>) -> encode::Result<()> { + ensure_not_too_large_or_too_small((*self $(as $real)?).into(), ty)?; + + out.buffer().extend_from_slice(&self.to_be_bytes()); + + Ok(()) + } + } + + impl<'r> Decode<'r, Postgres> for $ty { + fn decode(value: PgRawValue<'r>) -> decode::Result { + decode_int(&value) + } + } + }; +} + +impl_type_int! { i8 => SMALLINT } +impl_type_int! { i16 => SMALLINT } +impl_type_int! { i32 => INTEGER } +impl_type_int! { i64 => BIGINT } + +impl_type_int! { u8 => SMALLINT } +impl_type_int! { u16 => SMALLINT } +impl_type_int! { u32 => INTEGER } +impl_type_int! { u64 => BIGINT } diff --git a/sqlx/src/postgres.rs b/sqlx/src/postgres.rs index b56ed9a3..85bac5fa 100644 --- a/sqlx/src/postgres.rs +++ b/sqlx/src/postgres.rs @@ -8,6 +8,6 @@ use crate::DefaultRuntime; pub type PgConnection = sqlx_postgres::PgConnection; pub use sqlx_postgres::{ - types, PgColumn, PgQueryResult, PgRawValue, PgRawValueFormat, PgRow, PgTypeId, Postgres, - PgConnectOptions, + types, PgColumn, PgConnectOptions, PgQueryResult, PgRawValue, PgRawValueFormat, PgRow, + PgTypeId, Postgres, };