diff --git a/sqlx-core/src/any/arguments.rs b/sqlx-core/src/any/arguments.rs index ca438b028..7652430c3 100644 --- a/sqlx-core/src/any/arguments.rs +++ b/sqlx-core/src/any/arguments.rs @@ -1,7 +1,8 @@ use crate::any::value::AnyValueKind; use crate::any::Any; use crate::arguments::Arguments; -use crate::encode::Encode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; use crate::types::Type; pub struct AnyArguments<'q> { @@ -16,11 +17,16 @@ impl<'q> Arguments<'q> for AnyArguments<'q> { self.values.0.reserve(additional); } - fn add(&mut self, value: T) + fn add(&mut self, value: T) -> Result<(), BoxDynError> where T: 'q + Encode<'q, Self::Database> + Type, { - let _ = value.encode(&mut self.values); + let _: IsNull = value.encode(&mut self.values)?; + Ok(()) + } + + fn len(&self) -> usize { + self.values.0.len() } } @@ -36,7 +42,7 @@ impl<'q> Default for AnyArguments<'q> { impl<'q> AnyArguments<'q> { #[doc(hidden)] - pub fn convert_to<'a, A: Arguments<'a>>(&'a self) -> A + pub fn convert_to<'a, A: Arguments<'a>>(&'a self) -> Result where 'q: 'a, Option: Type + Encode<'a, A::Database>, @@ -62,9 +68,9 @@ impl<'q> AnyArguments<'q> { AnyValueKind::Double(d) => out.add(d), AnyValueKind::Text(t) => out.add(&**t), AnyValueKind::Blob(b) => out.add(&**b), - } + }? } - out + Ok(out) } } diff --git a/sqlx-core/src/any/connection/executor.rs b/sqlx-core/src/any/connection/executor.rs index 1c10b4e42..a9b0080d1 100644 --- a/sqlx-core/src/any/connection/executor.rs +++ b/sqlx-core/src/any/connection/executor.rs @@ -5,6 +5,8 @@ use crate::executor::{Execute, Executor}; use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; +use futures_util::{stream, FutureExt, StreamExt}; +use std::future; impl<'c> Executor<'c> for &'c mut AnyConnection { type Database = Any; @@ -17,7 +19,10 @@ impl<'c> Executor<'c> for &'c mut AnyConnection { 'c: 'e, E: Execute<'q, Any>, { - let arguments = query.take_arguments(); + let arguments = match query.take_arguments().map_err(Error::Encode) { + Ok(arguments) => arguments, + Err(error) => return stream::once(future::ready(Err(error))).boxed(), + }; self.backend.fetch_many(query.sql(), arguments) } @@ -29,7 +34,10 @@ impl<'c> Executor<'c> for &'c mut AnyConnection { 'c: 'e, E: Execute<'q, Self::Database>, { - let arguments = query.take_arguments(); + let arguments = match query.take_arguments().map_err(Error::Encode) { + Ok(arguments) => arguments, + Err(error) => return future::ready(Err(error)).boxed(), + }; self.backend.fetch_optional(query.sql(), arguments) } diff --git a/sqlx-core/src/any/mod.rs b/sqlx-core/src/any/mod.rs index 56f6dab87..2c0c6649e 100644 --- a/sqlx-core/src/any/mod.rs +++ b/sqlx-core/src/any/mod.rs @@ -67,12 +67,15 @@ impl<'q, T> Encode<'q, Any> for Option where T: Encode<'q, Any> + 'q, { - fn encode_by_ref(&self, buf: &mut AnyArgumentBuffer<'q>) -> crate::encode::IsNull { + fn encode_by_ref( + &self, + buf: &mut AnyArgumentBuffer<'q>, + ) -> Result { if let Some(value) = self { value.encode_by_ref(buf) } else { buf.0.push(AnyValueKind::Null); - crate::encode::IsNull::Yes + Ok(crate::encode::IsNull::Yes) } } } diff --git a/sqlx-core/src/any/types/blob.rs b/sqlx-core/src/any/types/blob.rs index 3f33b746f..851c93bf5 100644 --- a/sqlx-core/src/any/types/blob.rs +++ b/sqlx-core/src/any/types/blob.rs @@ -15,9 +15,12 @@ impl Type for [u8] { } impl<'q> Encode<'q, Any> for &'q [u8] { - fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer<'q>) -> IsNull { + fn encode_by_ref( + &self, + buf: &mut ::ArgumentBuffer<'q>, + ) -> Result { buf.0.push(AnyValueKind::Blob((*self).into())); - IsNull::No + Ok(IsNull::No) } } @@ -42,9 +45,12 @@ impl Type for Vec { } impl<'q> Encode<'q, Any> for Vec { - fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer<'q>) -> IsNull { + fn encode_by_ref( + &self, + buf: &mut ::ArgumentBuffer<'q>, + ) -> Result { buf.0.push(AnyValueKind::Blob(Cow::Owned(self.clone()))); - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-core/src/any/types/bool.rs b/sqlx-core/src/any/types/bool.rs index 09f7e1f79..fb7ee9d5d 100644 --- a/sqlx-core/src/any/types/bool.rs +++ b/sqlx-core/src/any/types/bool.rs @@ -14,9 +14,12 @@ impl Type for bool { } impl<'q> Encode<'q, Any> for bool { - fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer<'q>) -> IsNull { + fn encode_by_ref( + &self, + buf: &mut ::ArgumentBuffer<'q>, + ) -> Result { buf.0.push(AnyValueKind::Bool(*self)); - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-core/src/any/types/float.rs b/sqlx-core/src/any/types/float.rs index 47b4b24d3..01d6073a2 100644 --- a/sqlx-core/src/any/types/float.rs +++ b/sqlx-core/src/any/types/float.rs @@ -14,9 +14,9 @@ impl Type for f32 { } impl<'q> Encode<'q, Any> for f32 { - fn encode_by_ref(&self, buf: &mut AnyArgumentBuffer<'q>) -> IsNull { + fn encode_by_ref(&self, buf: &mut AnyArgumentBuffer<'q>) -> Result { buf.0.push(AnyValueKind::Real(*self)); - IsNull::No + Ok(IsNull::No) } } @@ -38,9 +38,12 @@ impl Type for f64 { } impl<'q> Encode<'q, Any> for f64 { - fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer<'q>) -> IsNull { + fn encode_by_ref( + &self, + buf: &mut ::ArgumentBuffer<'q>, + ) -> Result { buf.0.push(AnyValueKind::Double(*self)); - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-core/src/any/types/int.rs b/sqlx-core/src/any/types/int.rs index ae8d0e71f..56152af14 100644 --- a/sqlx-core/src/any/types/int.rs +++ b/sqlx-core/src/any/types/int.rs @@ -18,9 +18,12 @@ impl Type for i16 { } impl<'q> Encode<'q, Any> for i16 { - fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer<'q>) -> IsNull { + fn encode_by_ref( + &self, + buf: &mut ::ArgumentBuffer<'q>, + ) -> Result { buf.0.push(AnyValueKind::SmallInt(*self)); - IsNull::No + Ok(IsNull::No) } } @@ -43,9 +46,12 @@ impl Type for i32 { } impl<'q> Encode<'q, Any> for i32 { - fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer<'q>) -> IsNull { + fn encode_by_ref( + &self, + buf: &mut ::ArgumentBuffer<'q>, + ) -> Result { buf.0.push(AnyValueKind::Integer(*self)); - IsNull::No + Ok(IsNull::No) } } @@ -68,9 +74,12 @@ impl Type for i64 { } impl<'q> Encode<'q, Any> for i64 { - fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer<'q>) -> IsNull { + fn encode_by_ref( + &self, + buf: &mut ::ArgumentBuffer<'q>, + ) -> Result { buf.0.push(AnyValueKind::BigInt(*self)); - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-core/src/any/types/str.rs b/sqlx-core/src/any/types/str.rs index 3ce6d28e5..4c0083269 100644 --- a/sqlx-core/src/any/types/str.rs +++ b/sqlx-core/src/any/types/str.rs @@ -16,15 +16,18 @@ impl Type for str { } impl<'a> Encode<'a, Any> for &'a str { - fn encode(self, buf: &mut ::ArgumentBuffer<'a>) -> IsNull + fn encode(self, buf: &mut ::ArgumentBuffer<'a>) -> Result where Self: Sized, { buf.0.push(AnyValueKind::Text(self.into())); - IsNull::No + Ok(IsNull::No) } - fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer<'a>) -> IsNull { + fn encode_by_ref( + &self, + buf: &mut ::ArgumentBuffer<'a>, + ) -> Result { (*self).encode(buf) } } @@ -50,9 +53,12 @@ impl Type for String { } impl<'q> Encode<'q, Any> for String { - fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer<'q>) -> IsNull { + fn encode_by_ref( + &self, + buf: &mut ::ArgumentBuffer<'q>, + ) -> Result { buf.0.push(AnyValueKind::Text(Cow::Owned(self.clone()))); - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-core/src/arguments.rs b/sqlx-core/src/arguments.rs index 61988eb21..8dcc04387 100644 --- a/sqlx-core/src/arguments.rs +++ b/sqlx-core/src/arguments.rs @@ -2,6 +2,7 @@ use crate::database::Database; use crate::encode::Encode; +use crate::error::BoxDynError; use crate::types::Type; use std::fmt::{self, Write}; @@ -14,10 +15,13 @@ pub trait Arguments<'q>: Send + Sized + Default { fn reserve(&mut self, additional: usize, size: usize); /// Add the value to the end of the arguments. - fn add(&mut self, value: T) + fn add(&mut self, value: T) -> Result<(), BoxDynError> where T: 'q + Encode<'q, Self::Database> + Type; + /// The number of arguments that were already added. + fn len(&self) -> usize; + fn format_placeholder(&self, writer: &mut W) -> fmt::Result { writer.write_str("?") } diff --git a/sqlx-core/src/encode.rs b/sqlx-core/src/encode.rs index 0ba186594..2d28641f9 100644 --- a/sqlx-core/src/encode.rs +++ b/sqlx-core/src/encode.rs @@ -3,8 +3,10 @@ use std::mem; use crate::database::Database; +use crate::error::BoxDynError; /// The return type of [Encode::encode]. +#[must_use] pub enum IsNull { /// The value is null; no data was written. Yes, @@ -15,11 +17,16 @@ pub enum IsNull { No, } +impl IsNull { + pub fn is_null(&self) -> bool { + matches!(self, IsNull::Yes) + } +} + /// Encode a single value to be sent to the database. pub trait Encode<'q, DB: Database> { /// Writes the value of `self` into `buf` in the expected format for the database. - #[must_use] - fn encode(self, buf: &mut ::ArgumentBuffer<'q>) -> IsNull + fn encode(self, buf: &mut ::ArgumentBuffer<'q>) -> Result where Self: Sized, { @@ -30,8 +37,10 @@ pub trait Encode<'q, DB: Database> { /// /// Where possible, make use of `encode` instead as it can take advantage of re-using /// memory. - #[must_use] - fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer<'q>) -> IsNull; + fn encode_by_ref( + &self, + buf: &mut ::ArgumentBuffer<'q>, + ) -> Result; fn produces(&self) -> Option { // `produces` is inherently a hook to allow database drivers to produce value-dependent @@ -50,12 +59,15 @@ where T: Encode<'q, DB>, { #[inline] - fn encode(self, buf: &mut ::ArgumentBuffer<'q>) -> IsNull { + fn encode(self, buf: &mut ::ArgumentBuffer<'q>) -> Result { >::encode_by_ref(self, buf) } #[inline] - fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer<'q>) -> IsNull { + fn encode_by_ref( + &self, + buf: &mut ::ArgumentBuffer<'q>, + ) -> Result { <&T as Encode>::encode(self, buf) } @@ -90,11 +102,11 @@ macro_rules! impl_encode_for_option { fn encode( self, buf: &mut <$DB as $crate::database::Database>::ArgumentBuffer<'q>, - ) -> $crate::encode::IsNull { + ) -> Result<$crate::encode::IsNull, $crate::error::BoxDynError> { if let Some(v) = self { v.encode(buf) } else { - $crate::encode::IsNull::Yes + Ok($crate::encode::IsNull::Yes) } } @@ -102,11 +114,11 @@ macro_rules! impl_encode_for_option { fn encode_by_ref( &self, buf: &mut <$DB as $crate::database::Database>::ArgumentBuffer<'q>, - ) -> $crate::encode::IsNull { + ) -> Result<$crate::encode::IsNull, $crate::error::BoxDynError> { if let Some(v) = self { v.encode_by_ref(buf) } else { - $crate::encode::IsNull::Yes + Ok($crate::encode::IsNull::Yes) } } diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 5c4c223fa..042342ef9 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -78,6 +78,10 @@ pub enum Error { source: BoxDynError, }, + /// Error occured while encoding a value. + #[error("error occured while encoding a value: {0}")] + Encode(#[source] BoxDynError), + /// Error occurred while decoding a value. #[error("error occurred while decoding: {0}")] Decode(#[source] BoxDynError), diff --git a/sqlx-core/src/executor.rs b/sqlx-core/src/executor.rs index e3f245d92..b38b4e723 100644 --- a/sqlx-core/src/executor.rs +++ b/sqlx-core/src/executor.rs @@ -1,6 +1,6 @@ use crate::database::Database; use crate::describe::Describe; -use crate::error::Error; +use crate::error::{BoxDynError, Error}; use either::Either; use futures_core::future::BoxFuture; @@ -199,10 +199,12 @@ pub trait Execute<'q, DB: Database>: Send + Sized { /// Returns the arguments to be bound against the query string. /// - /// Returning `None` for `Arguments` indicates to use a "simple" query protocol and to not - /// prepare the query. Returning `Some(Default::default())` is an empty arguments object that + /// Returning `Ok(None)` for `Arguments` indicates to use a "simple" query protocol and to not + /// prepare the query. Returning `Ok(Some(Default::default()))` is an empty arguments object that /// will be prepared (and cached) before execution. - fn take_arguments(&mut self) -> Option<::Arguments<'q>>; + /// + /// Returns `Err` if encoding any of the arguments failed. + fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError>; /// Returns `true` if the statement should be cached. fn persistent(&self) -> bool; @@ -222,8 +224,8 @@ impl<'q, DB: Database> Execute<'q, DB> for &'q str { } #[inline] - fn take_arguments(&mut self) -> Option<::Arguments<'q>> { - None + fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError> { + Ok(None) } #[inline] @@ -244,8 +246,8 @@ impl<'q, DB: Database> Execute<'q, DB> for (&'q str, Option<::Ar } #[inline] - fn take_arguments(&mut self) -> Option<::Arguments<'q>> { - self.1.take() + fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError> { + Ok(self.1.take()) } #[inline] diff --git a/sqlx-core/src/query.rs b/sqlx-core/src/query.rs index 0b90e0574..0881c9ccf 100644 --- a/sqlx-core/src/query.rs +++ b/sqlx-core/src/query.rs @@ -7,7 +7,7 @@ use futures_util::{future, StreamExt, TryFutureExt, TryStreamExt}; use crate::arguments::{Arguments, IntoArguments}; use crate::database::{Database, HasStatementCache}; use crate::encode::Encode; -use crate::error::Error; +use crate::error::{BoxDynError, Error}; use crate::executor::{Execute, Executor}; use crate::statement::Statement; use crate::types::Type; @@ -16,7 +16,7 @@ use crate::types::Type; #[must_use = "query must be executed to affect database"] pub struct Query<'q, DB: Database, A> { pub(crate) statement: Either<&'q str, &'q DB::Statement<'q>>, - pub(crate) arguments: Option, + pub(crate) arguments: Option>, pub(crate) database: PhantomData, pub(crate) persistent: bool, } @@ -59,8 +59,11 @@ where } #[inline] - fn take_arguments(&mut self) -> Option<::Arguments<'q>> { - self.arguments.take().map(IntoArguments::into_arguments) + fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError> { + self.arguments + .take() + .transpose() + .map(|option| option.map(IntoArguments::into_arguments)) } #[inline] @@ -78,13 +81,43 @@ impl<'q, DB: Database> Query<'q, DB, ::Arguments<'q>> { /// /// There is no validation that the value is of the type expected by the query. Most SQL /// flavors will perform type coercion (Postgres will return a database error). + /// + /// If encoding the value fails, the error is stored and later surfaced when executing the query. pub fn bind + Type>(mut self, value: T) -> Self { - if let Some(arguments) = &mut self.arguments { - arguments.add(value); + let Ok(arguments) = self.get_arguments() else { + return self; + }; + + let argument_number = arguments.len() + 1; + if let Err(error) = arguments.add(value) { + self.arguments = Some(Err(format!( + "Encoding argument ${argument_number} failed: {error}" + ) + .into())); } self } + + /// Like [`Query::try_bind`] but immediately returns an error if encoding the value failed. + pub fn try_bind + Type>( + &mut self, + value: T, + ) -> Result<(), BoxDynError> { + let arguments = self.get_arguments()?; + + arguments.add(value) + } + + fn get_arguments(&mut self) -> Result<&mut DB::Arguments<'q>, BoxDynError> { + let Some(Ok(arguments)) = self.arguments.as_mut().map(Result::as_mut) else { + return Err("A previous call to Query::bind produced an error" + .to_owned() + .into()); + }; + + Ok(arguments) + } } impl<'q, DB, A> Query<'q, DB, A> @@ -280,7 +313,7 @@ where } #[inline] - fn take_arguments(&mut self) -> Option<::Arguments<'q>> { + fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError> { self.inner.take_arguments() } @@ -472,7 +505,7 @@ where { Query { database: PhantomData, - arguments: Some(Default::default()), + arguments: Some(Ok(Default::default())), statement: Either::Right(statement), persistent: true, } @@ -489,7 +522,7 @@ where { Query { database: PhantomData, - arguments: Some(arguments), + arguments: Some(Ok(arguments)), statement: Either::Right(statement), persistent: true, } @@ -625,7 +658,7 @@ where { Query { database: PhantomData, - arguments: Some(Default::default()), + arguments: Some(Ok(Default::default())), statement: Either::Left(sql), persistent: true, } @@ -635,6 +668,18 @@ where /// /// See [`query()`][query] for details, such as supported syntax. pub fn query_with<'q, DB, A>(sql: &'q str, arguments: A) -> Query<'q, DB, A> +where + DB: Database, + A: IntoArguments<'q, DB>, +{ + query_with_result(sql, Ok(arguments)) +} + +/// Same as [`query_with`] but is initialized with a Result of arguments instead +pub fn query_with_result<'q, DB, A>( + sql: &'q str, + arguments: Result, +) -> Query<'q, DB, A> where DB: Database, A: IntoArguments<'q, DB>, diff --git a/sqlx-core/src/query_as.rs b/sqlx-core/src/query_as.rs index f84eed8e5..fbc7fab55 100644 --- a/sqlx-core/src/query_as.rs +++ b/sqlx-core/src/query_as.rs @@ -7,10 +7,10 @@ use futures_util::{StreamExt, TryStreamExt}; use crate::arguments::IntoArguments; use crate::database::{Database, HasStatementCache}; use crate::encode::Encode; -use crate::error::Error; +use crate::error::{BoxDynError, Error}; use crate::executor::{Execute, Executor}; use crate::from_row::FromRow; -use crate::query::{query, query_statement, query_statement_with, query_with, Query}; +use crate::query::{query, query_statement, query_statement_with, query_with_result, Query}; use crate::types::Type; /// A single SQL query as a prepared statement, mapping results using [`FromRow`]. @@ -37,7 +37,7 @@ where } #[inline] - fn take_arguments(&mut self) -> Option<::Arguments<'q>> { + fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError> { self.inner.take_arguments() } @@ -358,13 +358,27 @@ where /// For details about type mapping from [`FromRow`], see [`query_as()`]. #[inline] pub fn query_as_with<'q, DB, O, A>(sql: &'q str, arguments: A) -> QueryAs<'q, DB, O, A> +where + DB: Database, + A: IntoArguments<'q, DB>, + O: for<'r> FromRow<'r, DB::Row>, +{ + query_as_with_result(sql, Ok(arguments)) +} + +/// Same as [`query_as_with`] but takes arguments as a Result +#[inline] +pub fn query_as_with_result<'q, DB, O, A>( + sql: &'q str, + arguments: Result, +) -> QueryAs<'q, DB, O, A> where DB: Database, A: IntoArguments<'q, DB>, O: for<'r> FromRow<'r, DB::Row>, { QueryAs { - inner: query_with(sql, arguments), + inner: query_with_result(sql, arguments), output: PhantomData, } } diff --git a/sqlx-core/src/query_builder.rs b/sqlx-core/src/query_builder.rs index a764ad320..b071ff8a4 100644 --- a/sqlx-core/src/query_builder.rs +++ b/sqlx-core/src/query_builder.rs @@ -155,7 +155,7 @@ where .arguments .as_mut() .expect("BUG: Arguments taken already"); - arguments.add(value); + arguments.add(value).expect("Failed to add argument"); arguments .format_placeholder(&mut self.query) @@ -450,7 +450,7 @@ where Query { statement: Either::Left(&self.query), - arguments: self.arguments.take(), + arguments: self.arguments.take().map(Ok), database: PhantomData, persistent: true, } diff --git a/sqlx-core/src/query_scalar.rs b/sqlx-core/src/query_scalar.rs index 1fcc577d3..c131adcca 100644 --- a/sqlx-core/src/query_scalar.rs +++ b/sqlx-core/src/query_scalar.rs @@ -5,11 +5,11 @@ use futures_util::{StreamExt, TryFutureExt, TryStreamExt}; use crate::arguments::IntoArguments; use crate::database::{Database, HasStatementCache}; use crate::encode::Encode; -use crate::error::Error; +use crate::error::{BoxDynError, Error}; use crate::executor::{Execute, Executor}; use crate::from_row::FromRow; use crate::query_as::{ - query_as, query_as_with, query_statement_as, query_statement_as_with, QueryAs, + query_as, query_as_with_result, query_statement_as, query_statement_as_with, QueryAs, }; use crate::types::Type; @@ -34,7 +34,7 @@ where } #[inline] - fn take_arguments(&mut self) -> Option<::Arguments<'q>> { + fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError> { self.inner.take_arguments() } @@ -338,13 +338,27 @@ where /// For details about prepared statements and allowed SQL syntax, see [`query()`][crate::query::query]. #[inline] pub fn query_scalar_with<'q, DB, O, A>(sql: &'q str, arguments: A) -> QueryScalar<'q, DB, O, A> +where + DB: Database, + A: IntoArguments<'q, DB>, + (O,): for<'r> FromRow<'r, DB::Row>, +{ + query_scalar_with_result(sql, Ok(arguments)) +} + +/// Same as [`query_scalar_with`] but takes arguments as Result +#[inline] +pub fn query_scalar_with_result<'q, DB, O, A>( + sql: &'q str, + arguments: Result, +) -> QueryScalar<'q, DB, O, A> where DB: Database, A: IntoArguments<'q, DB>, (O,): for<'r> FromRow<'r, DB::Row>, { QueryScalar { - inner: query_as_with(sql, arguments), + inner: query_as_with_result(sql, arguments), } } diff --git a/sqlx-core/src/raw_sql.rs b/sqlx-core/src/raw_sql.rs index 5617bfee0..37627d445 100644 --- a/sqlx-core/src/raw_sql.rs +++ b/sqlx-core/src/raw_sql.rs @@ -2,6 +2,7 @@ use either::Either; use futures_core::stream::BoxStream; use crate::database::Database; +use crate::error::BoxDynError; use crate::executor::{Execute, Executor}; use crate::Error; @@ -126,8 +127,8 @@ impl<'q, DB: Database> Execute<'q, DB> for RawSql<'q> { None } - fn take_arguments(&mut self) -> Option<::Arguments<'q>> { - None + fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError> { + Ok(None) } fn persistent(&self) -> bool { diff --git a/sqlx-core/src/types/bstr.rs b/sqlx-core/src/types/bstr.rs index ef571a9bf..4b6daadfd 100644 --- a/sqlx-core/src/types/bstr.rs +++ b/sqlx-core/src/types/bstr.rs @@ -37,7 +37,10 @@ where DB: Database, &'q [u8]: Encode<'q, DB>, { - fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer<'q>) -> IsNull { + fn encode_by_ref( + &self, + buf: &mut ::ArgumentBuffer<'q>, + ) -> Result { <&[u8] as Encode>::encode(self.as_bytes(), buf) } } @@ -47,7 +50,10 @@ where DB: Database, Vec: Encode<'q, DB>, { - fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer<'q>) -> IsNull { + fn encode_by_ref( + &self, + buf: &mut ::ArgumentBuffer<'q>, + ) -> Result { as Encode>::encode(self.as_bytes().to_vec(), buf) } } diff --git a/sqlx-core/src/types/json.rs b/sqlx-core/src/types/json.rs index f2b58e70a..5e85092bc 100644 --- a/sqlx-core/src/types/json.rs +++ b/sqlx-core/src/types/json.rs @@ -89,21 +89,15 @@ impl AsMut for Json { } } -const JSON_SERIALIZE_ERR: &str = "failed to encode value as JSON; the most likely cause is \ - attempting to serialize a map with a non-string key type"; - // UNSTABLE: for driver use only! #[doc(hidden)] impl Json { - pub fn encode_to_string(&self) -> String { - // Encoding is supposed to be infallible so we don't have much choice but to panic here. - // However, I believe that's the right thing to do anyway as an object being unable - // to serialize to JSON is likely due to a bug or a malformed datastructure. - serde_json::to_string(self).expect(JSON_SERIALIZE_ERR) + pub fn encode_to_string(&self) -> Result { + serde_json::to_string(self) } - pub fn encode_to(&self, buf: &mut Vec) { - serde_json::to_writer(buf, self).expect(JSON_SERIALIZE_ERR) + pub fn encode_to(&self, buf: &mut Vec) -> Result<(), serde_json::Error> { + serde_json::to_writer(buf, self) } } @@ -141,7 +135,10 @@ where for<'a> Json<&'a Self>: Encode<'q, DB>, DB: Database, { - fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer<'q>) -> IsNull { + fn encode_by_ref( + &self, + buf: &mut ::ArgumentBuffer<'q>, + ) -> Result { as Encode<'q, DB>>::encode(Json(self), buf) } } diff --git a/sqlx-core/src/types/text.rs b/sqlx-core/src/types/text.rs index 9ef865ad0..f5e323eea 100644 --- a/sqlx-core/src/types/text.rs +++ b/sqlx-core/src/types/text.rs @@ -115,7 +115,7 @@ where String: Encode<'q, DB>, DB: Database, { - fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer<'q>) -> IsNull { + fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer<'q>) -> Result { self.0.to_string().encode(buf) } } diff --git a/sqlx-macros-core/src/derives/encode.rs b/sqlx-macros-core/src/derives/encode.rs index df370717b..4e9f83abc 100644 --- a/sqlx-macros-core/src/derives/encode.rs +++ b/sqlx-macros-core/src/derives/encode.rs @@ -85,7 +85,7 @@ fn expand_derive_encode_transparent( fn encode_by_ref( &self, buf: &mut ::ArgumentBuffer<#lifetime>, - ) -> ::sqlx::encode::IsNull { + ) -> ::std::result::Result<::sqlx::encode::IsNull, ::sqlx::error::BoxDynError> { <#ty as ::sqlx::encode::Encode<#lifetime, DB>>::encode_by_ref(&self.0, buf) } @@ -124,7 +124,7 @@ fn expand_derive_encode_weak_enum( fn encode_by_ref( &self, buf: &mut ::ArgumentBuffer<'q>, - ) -> ::sqlx::encode::IsNull { + ) -> ::std::result::Result<::sqlx::encode::IsNull, ::sqlx::error::BoxDynError> { let value = match self { #(#values)* }; @@ -174,7 +174,7 @@ fn expand_derive_encode_strong_enum( fn encode_by_ref( &self, buf: &mut ::ArgumentBuffer<'q>, - ) -> ::sqlx::encode::IsNull { + ) -> ::std::result::Result<::sqlx::encode::IsNull, ::sqlx::error::BoxDynError> { let val = match self { #(#value_arms)* }; @@ -228,7 +228,7 @@ fn expand_derive_encode_struct( let id = &field.ident; parse_quote!( - encoder.encode(&self. #id); + encoder.encode(&self. #id)?; ) }); @@ -249,14 +249,14 @@ fn expand_derive_encode_struct( fn encode_by_ref( &self, buf: &mut ::sqlx::postgres::PgArgumentBuffer, - ) -> ::sqlx::encode::IsNull { + ) -> ::std::result::Result<::sqlx::encode::IsNull, ::sqlx::error::BoxDynError> { let mut encoder = ::sqlx::postgres::types::PgRecordEncoder::new(buf); #(#writes)* encoder.finish(); - ::sqlx::encode::IsNull::No + ::std::result::Result::Ok(::sqlx::encode::IsNull::No) } fn size_hint(&self) -> ::std::primitive::usize { diff --git a/sqlx-macros-core/src/query/args.rs b/sqlx-macros-core/src/query/args.rs index f72f91d1f..80d4210f2 100644 --- a/sqlx-macros-core/src/query/args.rs +++ b/sqlx-macros-core/src/query/args.rs @@ -17,7 +17,7 @@ pub fn quote_args( if input.arg_exprs.is_empty() { return Ok(quote! { - let query_args = <#db_path as ::sqlx::database::Database>::Arguments::<'_>::default(); + let query_args = ::core::result::Result::<_, ::sqlx::error::BoxDynError>::Ok(<#db_path as ::sqlx::database::Database>::Arguments::<'_>::default()); }); } @@ -109,7 +109,8 @@ pub fn quote_args( #args_count, 0 #(+ ::sqlx::encode::Encode::<#db_path>::size_hint(#arg_name))* ); - #(query_args.add(#arg_name);)* + let query_args = ::core::result::Result::<_, ::sqlx::error::BoxDynError>::Ok(query_args) + #(.and_then(move |mut query_args| query_args.add(#arg_name).map(move |()| query_args) ))*; }) } diff --git a/sqlx-macros-core/src/query/mod.rs b/sqlx-macros-core/src/query/mod.rs index cb60ea35b..cd5d14d85 100644 --- a/sqlx-macros-core/src/query/mod.rs +++ b/sqlx-macros-core/src/query/mod.rs @@ -276,7 +276,7 @@ where let sql = &input.sql; quote! { - ::sqlx::query_with::<#db_path, _>(#sql, #query_args) + ::sqlx::__query_with_result::<#db_path, _>(#sql, #query_args) } } else { match input.record_type { diff --git a/sqlx-macros-core/src/query/output.rs b/sqlx-macros-core/src/query/output.rs index 905c90306..d54fa24ca 100644 --- a/sqlx-macros-core/src/query/output.rs +++ b/sqlx-macros-core/src/query/output.rs @@ -173,7 +173,7 @@ pub fn quote_query_as( }; quote! { - ::sqlx::query_with::<#db_path, _>(#sql, #bind_args).try_map(|row: #row_path| { + ::sqlx::__query_with_result::<#db_path, _>(#sql, #bind_args).try_map(|row: #row_path| { use ::sqlx::Row as _; #(#instantiations)* @@ -216,7 +216,7 @@ pub fn quote_query_scalar( let query = &input.sql; Ok(quote! { - ::sqlx::query_scalar_with::<#db, #ty, _>(#query, #bind_args) + ::sqlx::__query_scalar_with_result::<#db, #ty, _>(#query, #bind_args) }) } diff --git a/sqlx-mysql/src/any.rs b/sqlx-mysql/src/any.rs index 2f9f9f851..d8b6b3470 100644 --- a/sqlx-mysql/src/any.rs +++ b/sqlx-mysql/src/any.rs @@ -6,7 +6,7 @@ use crate::{ use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; -use futures_util::{StreamExt, TryFutureExt, TryStreamExt}; +use futures_util::{stream, StreamExt, TryFutureExt, TryStreamExt}; use sqlx_core::any::{ Any, AnyArguments, AnyColumn, AnyConnectOptions, AnyConnectionBackend, AnyQueryResult, AnyRow, AnyStatement, AnyTypeInfo, AnyTypeInfoKind, @@ -16,6 +16,7 @@ use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; use sqlx_core::transaction::TransactionManager; +use std::future; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = MySql); @@ -77,10 +78,15 @@ impl AnyConnectionBackend for MySqlConnection { arguments: Option>, ) -> BoxStream<'q, sqlx_core::Result>> { let persistent = arguments.is_some(); - let args = arguments.as_ref().map(AnyArguments::convert_to); + let arguments = match arguments.as_ref().map(AnyArguments::convert_to).transpose() { + Ok(arguments) => arguments, + Err(error) => { + return stream::once(future::ready(Err(sqlx_core::Error::Encode(error)))).boxed() + } + }; Box::pin( - self.run(query, args, persistent) + self.run(query, arguments, persistent) .try_flatten_stream() .map(|res| { Ok(match res? { @@ -97,10 +103,15 @@ impl AnyConnectionBackend for MySqlConnection { arguments: Option>, ) -> BoxFuture<'q, sqlx_core::Result>> { let persistent = arguments.is_some(); - let args = arguments.as_ref().map(AnyArguments::convert_to); + let arguments = arguments + .as_ref() + .map(AnyArguments::convert_to) + .transpose() + .map_err(sqlx_core::Error::Encode); Box::pin(async move { - let stream = self.run(query, args, persistent).await?; + let arguments = arguments?; + let stream = self.run(query, arguments, persistent).await?; futures_util::pin_mut!(stream); while let Some(result) = stream.try_next().await? { diff --git a/sqlx-mysql/src/arguments.rs b/sqlx-mysql/src/arguments.rs index 3731ea24d..464529cba 100644 --- a/sqlx-mysql/src/arguments.rs +++ b/sqlx-mysql/src/arguments.rs @@ -2,34 +2,38 @@ use crate::encode::{Encode, IsNull}; use crate::types::Type; use crate::{MySql, MySqlTypeInfo}; pub(crate) use sqlx_core::arguments::*; +use sqlx_core::error::BoxDynError; +use std::ops::Deref; /// Implementation of [`Arguments`] for MySQL. #[derive(Debug, Default, Clone)] pub struct MySqlArguments { pub(crate) values: Vec, pub(crate) types: Vec, - pub(crate) null_bitmap: Vec, + pub(crate) null_bitmap: NullBitMap, } impl MySqlArguments { - pub(crate) fn add<'q, T>(&mut self, value: T) + pub(crate) fn add<'q, T>(&mut self, value: T) -> Result<(), BoxDynError> where T: Encode<'q, MySql> + Type, { let ty = value.produces().unwrap_or_else(T::type_info); - let index = self.types.len(); + + let value_length_before_encoding = self.values.len(); + let is_null = match value.encode(&mut self.values) { + Ok(is_null) => is_null, + Err(error) => { + // reset the value buffer to its previous value if encoding failed so we don't leave a half-encoded value behind + self.values.truncate(value_length_before_encoding); + return Err(error); + } + }; self.types.push(ty); - self.null_bitmap.resize((index / 8) + 1, 0); + self.null_bitmap.push(is_null); - if let IsNull::Yes = value.encode(&mut self.values) { - self.null_bitmap[index / 8] |= (1 << (index % 8)) as u8; - } - } - - #[doc(hidden)] - pub fn len(&self) -> usize { - self.types.len() + Ok(()) } } @@ -41,10 +45,64 @@ impl<'q> Arguments<'q> for MySqlArguments { self.values.reserve(size); } - fn add(&mut self, value: T) + fn add(&mut self, value: T) -> Result<(), BoxDynError> where T: Encode<'q, Self::Database> + Type, { self.add(value) } + + fn len(&self) -> usize { + self.types.len() + } +} + +#[derive(Debug, Default, Clone)] +pub(crate) struct NullBitMap { + bytes: Vec, + length: usize, +} + +impl NullBitMap { + fn push(&mut self, is_null: IsNull) { + let byte_index = self.length / (u8::BITS as usize); + let bit_offset = self.length % (u8::BITS as usize); + + if bit_offset == 0 { + self.bytes.push(0); + } + + self.bytes[byte_index] |= u8::from(is_null.is_null()) << bit_offset; + self.length += 1; + } +} + +impl Deref for NullBitMap { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.bytes + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn null_bit_map_should_push_is_null() { + let mut bit_map = NullBitMap::default(); + + bit_map.push(IsNull::Yes); + bit_map.push(IsNull::No); + bit_map.push(IsNull::Yes); + bit_map.push(IsNull::No); + bit_map.push(IsNull::Yes); + bit_map.push(IsNull::No); + bit_map.push(IsNull::Yes); + bit_map.push(IsNull::No); + bit_map.push(IsNull::Yes); + + assert_eq!([0b01010101, 0b1].as_slice(), bit_map.deref()); + } } diff --git a/sqlx-mysql/src/connection/executor.rs b/sqlx-mysql/src/connection/executor.rs index 21fec1ec6..474337cd6 100644 --- a/sqlx-mysql/src/connection/executor.rs +++ b/sqlx-mysql/src/connection/executor.rs @@ -244,10 +244,11 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { E: Execute<'q, Self::Database>, { let sql = query.sql(); - let arguments = query.take_arguments(); + let arguments = query.take_arguments().map_err(Error::Encode); let persistent = query.persistent(); Box::pin(try_stream! { + let arguments = arguments?; let s = self.run(sql, arguments, persistent).await?; pin_mut!(s); diff --git a/sqlx-mysql/src/protocol/statement/execute.rs b/sqlx-mysql/src/protocol/statement/execute.rs index 8bd367e25..e1bf998b4 100644 --- a/sqlx-mysql/src/protocol/statement/execute.rs +++ b/sqlx-mysql/src/protocol/statement/execute.rs @@ -19,7 +19,7 @@ impl<'q> Encode<'_, Capabilities> for Execute<'q> { buf.extend(&1_u32.to_le_bytes()); // iterations (always 1): int<4> if !self.arguments.types.is_empty() { - buf.extend(&*self.arguments.null_bitmap); + buf.extend_from_slice(&self.arguments.null_bitmap); buf.push(1); // send type to server for ty in &self.arguments.types { diff --git a/sqlx-mysql/src/types/bigdecimal.rs b/sqlx-mysql/src/types/bigdecimal.rs index d072db27c..11bca0480 100644 --- a/sqlx-mysql/src/types/bigdecimal.rs +++ b/sqlx-mysql/src/types/bigdecimal.rs @@ -19,10 +19,10 @@ impl Type for BigDecimal { } impl Encode<'_, MySql> for BigDecimal { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.put_str_lenenc(&self.to_string()); - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-mysql/src/types/bool.rs b/sqlx-mysql/src/types/bool.rs index 92793cdf9..2ba12e917 100644 --- a/sqlx-mysql/src/types/bool.rs +++ b/sqlx-mysql/src/types/bool.rs @@ -32,7 +32,7 @@ impl Type for bool { } impl Encode<'_, MySql> for bool { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { >::encode(*self as i8, buf) } } diff --git a/sqlx-mysql/src/types/bytes.rs b/sqlx-mysql/src/types/bytes.rs index 69f132756..ade079ad4 100644 --- a/sqlx-mysql/src/types/bytes.rs +++ b/sqlx-mysql/src/types/bytes.rs @@ -27,10 +27,10 @@ impl Type for [u8] { } impl Encode<'_, MySql> for &'_ [u8] { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.put_bytes_lenenc(self); - IsNull::No + Ok(IsNull::No) } } @@ -51,7 +51,7 @@ impl Type for Box<[u8]> { } impl Encode<'_, MySql> for Box<[u8]> { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { <&[u8] as Encode>::encode(self.as_ref(), buf) } } @@ -73,7 +73,7 @@ impl Type for Vec { } impl Encode<'_, MySql> for Vec { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { <&[u8] as Encode>::encode(&**self, buf) } } diff --git a/sqlx-mysql/src/types/chrono.rs b/sqlx-mysql/src/types/chrono.rs index 97dc06eda..9013ce72c 100644 --- a/sqlx-mysql/src/types/chrono.rs +++ b/sqlx-mysql/src/types/chrono.rs @@ -24,7 +24,7 @@ impl Type for DateTime { /// Note: assumes the connection's `time_zone` is set to `+00:00` (UTC). impl Encode<'_, MySql> for DateTime { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { Encode::::encode(&self.naive_utc(), buf) } } @@ -50,7 +50,7 @@ impl Type for DateTime { /// Note: assumes the connection's `time_zone` is set to `+00:00` (UTC). impl Encode<'_, MySql> for DateTime { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { Encode::::encode(&self.naive_utc(), buf) } } @@ -69,7 +69,7 @@ impl Type for NaiveTime { } impl Encode<'_, MySql> for NaiveTime { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { let len = Encode::::size_hint(self) - 1; buf.push(len as u8); @@ -82,7 +82,7 @@ impl Encode<'_, MySql> for NaiveTime { encode_time(self, len > 9, buf); - IsNull::No + Ok(IsNull::No) } fn size_hint(&self) -> usize { @@ -174,12 +174,12 @@ impl Type for NaiveDate { } impl Encode<'_, MySql> for NaiveDate { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(4); - encode_date(self, buf); + encode_date(self, buf)?; - IsNull::No + Ok(IsNull::No) } fn size_hint(&self) -> usize { @@ -216,17 +216,17 @@ impl Type for NaiveDateTime { } impl Encode<'_, MySql> for NaiveDateTime { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { let len = Encode::::size_hint(self) - 1; buf.push(len as u8); - encode_date(&self.date(), buf); + encode_date(&self.date(), buf)?; if len > 4 { encode_time(&self.time(), len > 8, buf); } - IsNull::No + Ok(IsNull::No) } fn size_hint(&self) -> usize { @@ -282,14 +282,16 @@ impl<'r> Decode<'r, MySql> for NaiveDateTime { } } -fn encode_date(date: &NaiveDate, buf: &mut Vec) { +fn encode_date(date: &NaiveDate, buf: &mut Vec) -> Result<(), BoxDynError> { // MySQL supports years from 1000 - 9999 let year = u16::try_from(date.year()) - .unwrap_or_else(|_| panic!("NaiveDateTime out of range for Mysql: {date}")); + .map_err(|_| format!("NaiveDateTime out of range for Mysql: {date}"))?; buf.extend_from_slice(&year.to_le_bytes()); buf.push(date.month() as u8); buf.push(date.day() as u8); + + Ok(()) } fn decode_date(mut buf: &[u8]) -> Result, BoxDynError> { diff --git a/sqlx-mysql/src/types/float.rs b/sqlx-mysql/src/types/float.rs index 0b36a5e87..13809f39f 100644 --- a/sqlx-mysql/src/types/float.rs +++ b/sqlx-mysql/src/types/float.rs @@ -33,18 +33,18 @@ impl Type for f64 { } impl Encode<'_, MySql> for f32 { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.extend(&self.to_le_bytes()); - IsNull::No + Ok(IsNull::No) } } impl Encode<'_, MySql> for f64 { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.extend(&self.to_le_bytes()); - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-mysql/src/types/inet.rs b/sqlx-mysql/src/types/inet.rs index 385ca4a61..19e59028e 100644 --- a/sqlx-mysql/src/types/inet.rs +++ b/sqlx-mysql/src/types/inet.rs @@ -18,10 +18,10 @@ impl Type for Ipv4Addr { } impl Encode<'_, MySql> for Ipv4Addr { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.put_str_lenenc(&self.to_string()); - IsNull::No + Ok(IsNull::No) } } @@ -46,10 +46,10 @@ impl Type for Ipv6Addr { } impl Encode<'_, MySql> for Ipv6Addr { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.put_str_lenenc(&self.to_string()); - IsNull::No + Ok(IsNull::No) } } @@ -74,10 +74,10 @@ impl Type for IpAddr { } impl Encode<'_, MySql> for IpAddr { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.put_str_lenenc(&self.to_string()); - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-mysql/src/types/int.rs b/sqlx-mysql/src/types/int.rs index c4896fa93..0e5b16225 100644 --- a/sqlx-mysql/src/types/int.rs +++ b/sqlx-mysql/src/types/int.rs @@ -59,34 +59,34 @@ impl Type for i64 { } impl Encode<'_, MySql> for i8 { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.extend(&self.to_le_bytes()); - IsNull::No + Ok(IsNull::No) } } impl Encode<'_, MySql> for i16 { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.extend(&self.to_le_bytes()); - IsNull::No + Ok(IsNull::No) } } impl Encode<'_, MySql> for i32 { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.extend(&self.to_le_bytes()); - IsNull::No + Ok(IsNull::No) } } impl Encode<'_, MySql> for i64 { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.extend(&self.to_le_bytes()); - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-mysql/src/types/json.rs b/sqlx-mysql/src/types/json.rs index b83ba83bc..b47baa355 100644 --- a/sqlx-mysql/src/types/json.rs +++ b/sqlx-mysql/src/types/json.rs @@ -26,7 +26,7 @@ impl Encode<'_, MySql> for Json where T: Serialize, { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { // Encode JSON as a length-prefixed string. // // The previous implementation encoded into an intermediate buffer to get the final length. @@ -47,14 +47,14 @@ where buf.extend_from_slice(&[0u8; 9]); let encode_start = buf.len(); - self.encode_to(buf); + self.encode_to(buf)?; let encoded_len = (buf.len() - encode_start) as u64; // This prefix indicates that the following 8 bytes are a little-endian integer. buf[lenenc_start] = 0xFE; buf[lenenc_start + 1..][..8].copy_from_slice(&encoded_len.to_le_bytes()); - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-mysql/src/types/mysql_time.rs b/sqlx-mysql/src/types/mysql_time.rs index d1275f8b1..f66d250a8 100644 --- a/sqlx-mysql/src/types/mysql_time.rs +++ b/sqlx-mysql/src/types/mysql_time.rs @@ -410,10 +410,13 @@ impl<'r> Decode<'r, MySql> for MySqlTime { } impl<'q> Encode<'q, MySql> for MySqlTime { - fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer<'q>) -> IsNull { + fn encode_by_ref( + &self, + buf: &mut ::ArgumentBuffer<'q>, + ) -> Result { if self.is_zero() { buf.put_u8(0); - return IsNull::No; + return Ok(IsNull::No); } buf.put_u8(self.encoded_len()); @@ -438,7 +441,7 @@ impl<'q> Encode<'q, MySql> for MySqlTime { buf.put_u32_le(microseconds); } - IsNull::No + Ok(IsNull::No) } fn size_hint(&self) -> usize { diff --git a/sqlx-mysql/src/types/rust_decimal.rs b/sqlx-mysql/src/types/rust_decimal.rs index 49ab2ded5..6e78243c7 100644 --- a/sqlx-mysql/src/types/rust_decimal.rs +++ b/sqlx-mysql/src/types/rust_decimal.rs @@ -19,10 +19,10 @@ impl Type for Decimal { } impl Encode<'_, MySql> for Decimal { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.put_str_lenenc(&self.to_string()); - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-mysql/src/types/str.rs b/sqlx-mysql/src/types/str.rs index 7bb2f039b..5c2f60ff3 100644 --- a/sqlx-mysql/src/types/str.rs +++ b/sqlx-mysql/src/types/str.rs @@ -49,10 +49,10 @@ impl Type for str { } impl Encode<'_, MySql> for &'_ str { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.put_str_lenenc(self); - IsNull::No + Ok(IsNull::No) } } @@ -73,7 +73,7 @@ impl Type for Box { } impl Encode<'_, MySql> for Box { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { <&str as Encode>::encode(&**self, buf) } } @@ -95,7 +95,7 @@ impl Type for String { } impl Encode<'_, MySql> for String { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { <&str as Encode>::encode(&**self, buf) } } @@ -117,7 +117,7 @@ impl Type for Cow<'_, str> { } impl Encode<'_, MySql> for Cow<'_, str> { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { match self { Cow::Borrowed(str) => <&str as Encode>::encode(*str, buf), Cow::Owned(str) => <&str as Encode>::encode(&**str, buf), diff --git a/sqlx-mysql/src/types/text.rs b/sqlx-mysql/src/types/text.rs index 6b6172897..ad61c1bee 100644 --- a/sqlx-mysql/src/types/text.rs +++ b/sqlx-mysql/src/types/text.rs @@ -20,7 +20,7 @@ impl<'q, T> Encode<'q, MySql> for Text where T: Display, { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { // We can't really do the trick like with Postgres where we reserve the space for the // length up-front and then overwrite it later, because MySQL appears to enforce that // length-encoded integers use the smallest encoding for the value: diff --git a/sqlx-mysql/src/types/time.rs b/sqlx-mysql/src/types/time.rs index dcdc4b5be..2d53839ed 100644 --- a/sqlx-mysql/src/types/time.rs +++ b/sqlx-mysql/src/types/time.rs @@ -23,7 +23,7 @@ impl Type for OffsetDateTime { } impl Encode<'_, MySql> for OffsetDateTime { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { let utc_dt = self.to_offset(UtcOffset::UTC); let primitive_dt = PrimitiveDateTime::new(utc_dt.date(), utc_dt.time()); @@ -46,7 +46,7 @@ impl Type for Time { } impl Encode<'_, MySql> for Time { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { let len = Encode::::size_hint(self) - 1; buf.push(len as u8); @@ -59,7 +59,7 @@ impl Encode<'_, MySql> for Time { encode_time(self, len > 9, buf); - IsNull::No + Ok(IsNull::No) } fn size_hint(&self) -> usize { @@ -149,12 +149,12 @@ impl Type for Date { } impl Encode<'_, MySql> for Date { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(4); - encode_date(self, buf); + encode_date(self, buf)?; - IsNull::No + Ok(IsNull::No) } fn size_hint(&self) -> usize { @@ -190,17 +190,17 @@ impl Type for PrimitiveDateTime { } impl Encode<'_, MySql> for PrimitiveDateTime { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { let len = Encode::::size_hint(self) - 1; buf.push(len as u8); - encode_date(&self.date(), buf); + encode_date(&self.date(), buf)?; if len > 4 { encode_time(&self.time(), len > 8, buf); } - IsNull::No + Ok(IsNull::No) } fn size_hint(&self) -> usize { @@ -267,14 +267,16 @@ impl<'r> Decode<'r, MySql> for PrimitiveDateTime { } } -fn encode_date(date: &Date, buf: &mut Vec) { +fn encode_date(date: &Date, buf: &mut Vec) -> Result<(), BoxDynError> { // MySQL supports years from 1000 - 9999 - let year = u16::try_from(date.year()) - .unwrap_or_else(|_| panic!("Date out of range for Mysql: {date}")); + let year = + u16::try_from(date.year()).map_err(|_| format!("Date out of range for Mysql: {date}"))?; buf.extend_from_slice(&year.to_le_bytes()); buf.push(date.month().into()); buf.push(date.day()); + + Ok(()) } fn decode_date(buf: &[u8]) -> Result, BoxDynError> { diff --git a/sqlx-mysql/src/types/uint.rs b/sqlx-mysql/src/types/uint.rs index ef383797c..79b131e8a 100644 --- a/sqlx-mysql/src/types/uint.rs +++ b/sqlx-mysql/src/types/uint.rs @@ -69,34 +69,34 @@ impl Type for u64 { } impl Encode<'_, MySql> for u8 { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.extend(&self.to_le_bytes()); - IsNull::No + Ok(IsNull::No) } } impl Encode<'_, MySql> for u16 { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.extend(&self.to_le_bytes()); - IsNull::No + Ok(IsNull::No) } } impl Encode<'_, MySql> for u32 { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.extend(&self.to_le_bytes()); - IsNull::No + Ok(IsNull::No) } } impl Encode<'_, MySql> for u64 { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.extend(&self.to_le_bytes()); - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-mysql/src/types/uuid.rs b/sqlx-mysql/src/types/uuid.rs index c18fdb785..53adfac0e 100644 --- a/sqlx-mysql/src/types/uuid.rs +++ b/sqlx-mysql/src/types/uuid.rs @@ -21,10 +21,10 @@ impl Type for Uuid { } impl Encode<'_, MySql> for Uuid { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.put_bytes_lenenc(self.as_bytes()); - IsNull::No + Ok(IsNull::No) } } @@ -49,10 +49,10 @@ impl Type for Hyphenated { } impl Encode<'_, MySql> for Hyphenated { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.put_str_lenenc(&self.to_string()); - IsNull::No + Ok(IsNull::No) } } @@ -79,10 +79,10 @@ impl Type for Simple { } impl Encode<'_, MySql> for Simple { - fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.put_str_lenenc(&self.to_string()); - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index 542e7f007..f1424ab6c 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -4,7 +4,8 @@ use crate::{ }; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; -use futures_util::{StreamExt, TryFutureExt, TryStreamExt}; +use futures_util::{stream, StreamExt, TryFutureExt, TryStreamExt}; +use std::future; pub use sqlx_core::any::*; @@ -76,10 +77,15 @@ impl AnyConnectionBackend for PgConnection { arguments: Option>, ) -> BoxStream<'q, sqlx_core::Result>> { let persistent = arguments.is_some(); - let args = arguments.as_ref().map(AnyArguments::convert_to); + let arguments = match arguments.as_ref().map(AnyArguments::convert_to).transpose() { + Ok(arguments) => arguments, + Err(error) => { + return stream::once(future::ready(Err(sqlx_core::Error::Encode(error)))).boxed() + } + }; Box::pin( - self.run(query, args, 0, persistent, None) + self.run(query, arguments, 0, persistent, None) .try_flatten_stream() .map( move |res: sqlx_core::Result>| match res? { @@ -96,10 +102,15 @@ impl AnyConnectionBackend for PgConnection { arguments: Option>, ) -> BoxFuture<'q, sqlx_core::Result>> { let persistent = arguments.is_some(); - let args = arguments.as_ref().map(AnyArguments::convert_to); + let arguments = arguments + .as_ref() + .map(AnyArguments::convert_to) + .transpose() + .map_err(sqlx_core::Error::Encode); Box::pin(async move { - let stream = self.run(query, args, 1, persistent, None).await?; + let arguments = arguments?; + let stream = self.run(query, arguments, 1, persistent, None).await?; futures_util::pin_mut!(stream); if let Some(Either::Right(row)) = stream.try_next().await? { diff --git a/sqlx-postgres/src/arguments.rs b/sqlx-postgres/src/arguments.rs index 40e1c11e3..dd6f239f8 100644 --- a/sqlx-postgres/src/arguments.rs +++ b/sqlx-postgres/src/arguments.rs @@ -8,6 +8,7 @@ use crate::types::Type; use crate::{PgConnection, PgTypeInfo, Postgres}; pub(crate) use sqlx_core::arguments::Arguments; +use sqlx_core::error::BoxDynError; // TODO: buf.patch(|| ...) is a poor name, can we think of a better name? Maybe `buf.lazy(||)` ? // TODO: Extend the patch system to support dynamic lengths @@ -59,19 +60,27 @@ pub struct PgArguments { } impl PgArguments { - pub(crate) fn add<'q, T>(&mut self, value: T) + pub(crate) fn add<'q, T>(&mut self, value: T) -> Result<(), BoxDynError> where T: Encode<'q, Postgres> + Type, { - // remember the type information for this value - self.types - .push(value.produces().unwrap_or_else(T::type_info)); + let type_info = value.produces().unwrap_or_else(T::type_info); + + let buffer_snapshot = self.buffer.snapshot(); // encode the value into our buffer - self.buffer.encode(value); + if let Err(error) = self.buffer.encode(value) { + // reset the value buffer to its previous value if encoding failed so we don't leave a half-encoded value behind + self.buffer.reset_to_snapshot(buffer_snapshot); + return Err(error); + }; + // remember the type information for this value + self.types.push(type_info); // increment the number of arguments we are tracking self.buffer.count += 1; + + Ok(()) } // Apply patches @@ -112,7 +121,7 @@ impl<'q> Arguments<'q> for PgArguments { self.buffer.reserve(size); } - fn add(&mut self, value: T) + fn add(&mut self, value: T) -> Result<(), BoxDynError> where T: Encode<'q, Self::Database> + Type, { @@ -122,10 +131,14 @@ impl<'q> Arguments<'q> for PgArguments { fn format_placeholder(&self, writer: &mut W) -> fmt::Result { write!(writer, "${}", self.buffer.count) } + + fn len(&self) -> usize { + self.buffer.count + } } impl PgArgumentBuffer { - pub(crate) fn encode<'q, T>(&mut self, value: T) + pub(crate) fn encode<'q, T>(&mut self, value: T) -> Result<(), BoxDynError> where T: Encode<'q, Postgres>, { @@ -134,7 +147,7 @@ impl PgArgumentBuffer { self.extend(&[0; 4]); // encode the value into our buffer - let len = if let IsNull::No = value.encode(self) { + let len = if let IsNull::No = value.encode(self)? { (self.len() - offset - 4) as i32 } else { // Write a -1 to indicate NULL @@ -145,6 +158,8 @@ impl PgArgumentBuffer { // write the len to the beginning of the value self[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes()); + + Ok(()) } // Adds a callback to be invoked later when we know the parameter type @@ -167,6 +182,44 @@ impl PgArgumentBuffer { self.extend_from_slice(&0_u32.to_be_bytes()); self.type_holes.push((offset, type_name.clone())); } + + fn snapshot(&self) -> PgArgumentBufferSnapshot { + let Self { + buffer, + count, + patches, + type_holes, + } = self; + + PgArgumentBufferSnapshot { + buffer_length: buffer.len(), + count: *count, + patches_length: patches.len(), + type_holes_length: type_holes.len(), + } + } + + fn reset_to_snapshot( + &mut self, + PgArgumentBufferSnapshot { + buffer_length, + count, + patches_length, + type_holes_length, + }: PgArgumentBufferSnapshot, + ) { + self.buffer.truncate(buffer_length); + self.count = count; + self.patches.truncate(patches_length); + self.type_holes.truncate(type_holes_length); + } +} + +struct PgArgumentBufferSnapshot { + buffer_length: usize, + count: usize, + patches_length: usize, + type_holes_length: usize, } impl Deref for PgArgumentBuffer { diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index 189ae3389..71f9b9b31 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -382,9 +382,10 @@ WHERE rngtypid = $1 bind + 2 ); - args.add(i as i32); - args.add(column.relation_id); - args.add(column.relation_attribute_no); + args.add(i as i32).map_err(Error::Encode)?; + args.add(column.relation_id).map_err(Error::Encode)?; + args.add(column.relation_attribute_no) + .map_err(Error::Encode)?; } nullable_query.push_str( diff --git a/sqlx-postgres/src/connection/executor.rs b/sqlx-postgres/src/connection/executor.rs index b1ce29c51..96155d928 100644 --- a/sqlx-postgres/src/connection/executor.rs +++ b/sqlx-postgres/src/connection/executor.rs @@ -370,10 +370,11 @@ impl<'c> Executor<'c> for &'c mut PgConnection { { let sql = query.sql(); let metadata = query.statement().map(|s| Arc::clone(&s.metadata)); - let arguments = query.take_arguments(); + let arguments = query.take_arguments().map_err(Error::Encode); let persistent = query.persistent(); Box::pin(try_stream! { + let arguments = arguments?; let s = self.run(sql, arguments, 0, persistent, metadata).await?; pin_mut!(s); @@ -395,10 +396,11 @@ impl<'c> Executor<'c> for &'c mut PgConnection { { let sql = query.sql(); let metadata = query.statement().map(|s| Arc::clone(&s.metadata)); - let arguments = query.take_arguments(); + let arguments = query.take_arguments().map_err(Error::Encode); let persistent = query.persistent(); Box::pin(async move { + let arguments = arguments?; let s = self.run(sql, arguments, 1, persistent, metadata).await?; pin_mut!(s); diff --git a/sqlx-postgres/src/types/array.rs b/sqlx-postgres/src/types/array.rs index 8ffcf5a35..f594ab8f3 100644 --- a/sqlx-postgres/src/types/array.rs +++ b/sqlx-postgres/src/types/array.rs @@ -136,7 +136,7 @@ where T: Encode<'q, Postgres>, { #[inline] - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { self.as_slice().encode_by_ref(buf) } } @@ -146,7 +146,7 @@ where for<'a> &'a [T]: Encode<'q, Postgres>, T: Encode<'q, Postgres>, { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { self.as_slice().encode_by_ref(buf) } } @@ -155,7 +155,7 @@ impl<'q, T> Encode<'q, Postgres> for &'_ [T] where T: Encode<'q, Postgres> + Type, { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { let type_info = if self.len() < 1 { T::type_info() } else { @@ -178,10 +178,10 @@ where buf.extend(&1_i32.to_be_bytes()); // lower bound for element in self.iter() { - buf.encode(element); + buf.encode(element)?; } - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-postgres/src/types/bigdecimal.rs b/sqlx-postgres/src/types/bigdecimal.rs index f42a1794b..5a6e500d3 100644 --- a/sqlx-postgres/src/types/bigdecimal.rs +++ b/sqlx-postgres/src/types/bigdecimal.rs @@ -137,24 +137,10 @@ impl TryFrom<&'_ BigDecimal> for PgNumeric { #[doc=include_str!("bigdecimal-range.md")] impl Encode<'_, Postgres> for BigDecimal { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { - // If the argument is too big, then we replace it with a less big argument. - // This less big argument is already outside the range of allowed PostgreSQL DECIMAL, which - // means that PostgreSQL will return the 22P03 error kind upon receiving it. This is the - // expected error, and the user should be ready to handle it anyway. - PgNumeric::try_from(self) - .unwrap_or_else(|_| { - PgNumeric::Number { - digits: vec![1], - // This is larger than the maximum allowed value, so Postgres should return an error. - scale: 0x4000, - weight: 0, - sign: sign_to_pg(self.sign()), - } - }) - .encode(buf); + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + PgNumeric::try_from(self)?.encode(buf); - IsNull::No + Ok(IsNull::No) } fn size_hint(&self) -> usize { diff --git a/sqlx-postgres/src/types/bit_vec.rs b/sqlx-postgres/src/types/bit_vec.rs index e1ad3fb50..2cb9943ce 100644 --- a/sqlx-postgres/src/types/bit_vec.rs +++ b/sqlx-postgres/src/types/bit_vec.rs @@ -30,11 +30,11 @@ impl PgHasArrayType for BitVec { } impl Encode<'_, Postgres> for BitVec { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { buf.extend(&(self.len() as i32).to_be_bytes()); buf.extend(self.to_bytes()); - IsNull::No + Ok(IsNull::No) } fn size_hint(&self) -> usize { diff --git a/sqlx-postgres/src/types/bool.rs b/sqlx-postgres/src/types/bool.rs index ec7538119..8c3e140d3 100644 --- a/sqlx-postgres/src/types/bool.rs +++ b/sqlx-postgres/src/types/bool.rs @@ -17,10 +17,10 @@ impl PgHasArrayType for bool { } impl Encode<'_, Postgres> for bool { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { buf.push(*self as u8); - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-postgres/src/types/bytes.rs b/sqlx-postgres/src/types/bytes.rs index b0c05ca84..45968837a 100644 --- a/sqlx-postgres/src/types/bytes.rs +++ b/sqlx-postgres/src/types/bytes.rs @@ -35,27 +35,27 @@ impl PgHasArrayType for [u8; N] { } impl Encode<'_, Postgres> for &'_ [u8] { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { buf.extend_from_slice(self); - IsNull::No + Ok(IsNull::No) } } impl Encode<'_, Postgres> for Box<[u8]> { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { <&[u8] as Encode>::encode(self.as_ref(), buf) } } impl Encode<'_, Postgres> for Vec { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { <&[u8] as Encode>::encode(self, buf) } } impl Encode<'_, Postgres> for [u8; N] { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { <&[u8] as Encode>::encode(self.as_slice(), buf) } } diff --git a/sqlx-postgres/src/types/chrono/date.rs b/sqlx-postgres/src/types/chrono/date.rs index da10bd1ac..4425f66e9 100644 --- a/sqlx-postgres/src/types/chrono/date.rs +++ b/sqlx-postgres/src/types/chrono/date.rs @@ -21,7 +21,7 @@ impl PgHasArrayType for NaiveDate { } impl Encode<'_, Postgres> for NaiveDate { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { // DATE is encoded as the days since epoch let days = (*self - postgres_epoch_date()).num_days() as i32; Encode::::encode(&days, buf) diff --git a/sqlx-postgres/src/types/chrono/datetime.rs b/sqlx-postgres/src/types/chrono/datetime.rs index 72aea4945..7487e9ca5 100644 --- a/sqlx-postgres/src/types/chrono/datetime.rs +++ b/sqlx-postgres/src/types/chrono/datetime.rs @@ -33,12 +33,11 @@ impl PgHasArrayType for DateTime { } impl Encode<'_, Postgres> for NaiveDateTime { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { - // FIXME: We should *really* be returning an error, Encode needs to be fallible + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { // TIMESTAMP is encoded as the microseconds since the epoch let us = (*self - postgres_epoch_datetime()) .num_microseconds() - .unwrap_or_else(|| panic!("NaiveDateTime out of range for Postgres: {self:?}")); + .ok_or_else(|| format!("NaiveDateTime out of range for Postgres: {self:?}"))?; Encode::::encode(&us, buf) } @@ -76,7 +75,7 @@ impl<'r> Decode<'r, Postgres> for NaiveDateTime { } impl Encode<'_, Postgres> for DateTime { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { Encode::::encode(self.naive_utc(), buf) } diff --git a/sqlx-postgres/src/types/chrono/time.rs b/sqlx-postgres/src/types/chrono/time.rs index 3bdd65ee3..31b212027 100644 --- a/sqlx-postgres/src/types/chrono/time.rs +++ b/sqlx-postgres/src/types/chrono/time.rs @@ -19,10 +19,11 @@ impl PgHasArrayType for NaiveTime { } impl Encode<'_, Postgres> for NaiveTime { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { // TIME is encoded as the microseconds since midnight - // NOTE: panic! is on overflow and 1 day does not have enough micros to overflow - let us = (*self - NaiveTime::default()).num_microseconds().unwrap(); + let us = (*self - NaiveTime::default()) + .num_microseconds() + .ok_or_else(|| format!("Time out of range for PostgreSQL: {self}"))?; Encode::::encode(&us, buf) } diff --git a/sqlx-postgres/src/types/citext.rs b/sqlx-postgres/src/types/citext.rs index 9fc131d7c..c0316ac82 100644 --- a/sqlx-postgres/src/types/citext.rs +++ b/sqlx-postgres/src/types/citext.rs @@ -94,7 +94,7 @@ impl PgHasArrayType for PgCiText { } impl Encode<'_, Postgres> for PgCiText { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { <&str as Encode>::encode(&**self, buf) } } diff --git a/sqlx-postgres/src/types/float.rs b/sqlx-postgres/src/types/float.rs index 2cb258d91..116a28c2d 100644 --- a/sqlx-postgres/src/types/float.rs +++ b/sqlx-postgres/src/types/float.rs @@ -19,10 +19,10 @@ impl PgHasArrayType for f32 { } impl Encode<'_, Postgres> for f32 { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { buf.extend(&self.to_be_bytes()); - IsNull::No + Ok(IsNull::No) } } @@ -48,10 +48,10 @@ impl PgHasArrayType for f64 { } impl Encode<'_, Postgres> for f64 { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { buf.extend(&self.to_be_bytes()); - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-postgres/src/types/int.rs b/sqlx-postgres/src/types/int.rs index 6c9adb1a7..0c852d52e 100644 --- a/sqlx-postgres/src/types/int.rs +++ b/sqlx-postgres/src/types/int.rs @@ -44,10 +44,10 @@ impl PgHasArrayType for i8 { } impl Encode<'_, Postgres> for i8 { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { buf.extend(&self.to_be_bytes()); - IsNull::No + Ok(IsNull::No) } } @@ -89,10 +89,10 @@ impl PgHasArrayType for i16 { } impl Encode<'_, Postgres> for i16 { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { buf.extend(&self.to_be_bytes()); - IsNull::No + Ok(IsNull::No) } } @@ -115,10 +115,10 @@ impl PgHasArrayType for i32 { } impl Encode<'_, Postgres> for i32 { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { buf.extend(&self.to_be_bytes()); - IsNull::No + Ok(IsNull::No) } } @@ -141,10 +141,10 @@ impl PgHasArrayType for i64 { } impl Encode<'_, Postgres> for i64 { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { buf.extend(&self.to_be_bytes()); - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-postgres/src/types/interval.rs b/sqlx-postgres/src/types/interval.rs index 0712a26a7..07e521c8f 100644 --- a/sqlx-postgres/src/types/interval.rs +++ b/sqlx-postgres/src/types/interval.rs @@ -54,12 +54,12 @@ impl<'de> Decode<'de, Postgres> for PgInterval { } impl Encode<'_, Postgres> for PgInterval { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { buf.extend(&self.microseconds.to_be_bytes()); buf.extend(&self.days.to_be_bytes()); buf.extend(&self.months.to_be_bytes()); - IsNull::No + Ok(IsNull::No) } fn size_hint(&self) -> usize { @@ -83,10 +83,8 @@ impl PgHasArrayType for std::time::Duration { } impl Encode<'_, Postgres> for std::time::Duration { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { - PgInterval::try_from(*self) - .expect("failed to encode `std::time::Duration`") - .encode_by_ref(buf) + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + PgInterval::try_from(*self)?.encode_by_ref(buf) } fn size_hint(&self) -> usize { @@ -130,8 +128,8 @@ impl PgHasArrayType for chrono::Duration { #[cfg(feature = "chrono")] impl Encode<'_, Postgres> for chrono::Duration { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { - let pg_interval = PgInterval::try_from(*self).expect("Failed to encode chrono::Duration"); + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + let pg_interval = PgInterval::try_from(*self)?; pg_interval.encode_by_ref(buf) } @@ -192,8 +190,8 @@ impl PgHasArrayType for time::Duration { #[cfg(feature = "time")] impl Encode<'_, Postgres> for time::Duration { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { - let pg_interval = PgInterval::try_from(*self).expect("Failed to encode time::Duration"); + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + let pg_interval = PgInterval::try_from(*self)?; pg_interval.encode_by_ref(buf) } @@ -234,7 +232,7 @@ fn test_encode_interval() { }; assert!(matches!( Encode::::encode(&interval, &mut buf), - IsNull::No + Ok(IsNull::No) )); assert_eq!(&**buf, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); buf.clear(); @@ -246,7 +244,7 @@ fn test_encode_interval() { }; assert!(matches!( Encode::::encode(&interval, &mut buf), - IsNull::No + Ok(IsNull::No) )); assert_eq!(&**buf, [0, 0, 0, 0, 0, 0, 3, 232, 0, 0, 0, 0, 0, 0, 0, 0]); buf.clear(); @@ -258,7 +256,7 @@ fn test_encode_interval() { }; assert!(matches!( Encode::::encode(&interval, &mut buf), - IsNull::No + Ok(IsNull::No) )); assert_eq!(&**buf, [0, 0, 0, 0, 0, 15, 66, 64, 0, 0, 0, 0, 0, 0, 0, 0]); buf.clear(); @@ -270,7 +268,7 @@ fn test_encode_interval() { }; assert!(matches!( Encode::::encode(&interval, &mut buf), - IsNull::No + Ok(IsNull::No) )); assert_eq!( &**buf, @@ -285,7 +283,7 @@ fn test_encode_interval() { }; assert!(matches!( Encode::::encode(&interval, &mut buf), - IsNull::No + Ok(IsNull::No) )); assert_eq!(&**buf, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]); buf.clear(); @@ -297,7 +295,7 @@ fn test_encode_interval() { }; assert!(matches!( Encode::::encode(&interval, &mut buf), - IsNull::No + Ok(IsNull::No) )); assert_eq!(&**buf, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]); buf.clear(); diff --git a/sqlx-postgres/src/types/ipaddr.rs b/sqlx-postgres/src/types/ipaddr.rs index 98898ebca..ee587eda1 100644 --- a/sqlx-postgres/src/types/ipaddr.rs +++ b/sqlx-postgres/src/types/ipaddr.rs @@ -35,7 +35,7 @@ impl<'db> Encode<'db, Postgres> for IpAddr where IpNetwork: Encode<'db, Postgres>, { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { IpNetwork::from(*self).encode_by_ref(buf) } diff --git a/sqlx-postgres/src/types/ipnetwork.rs b/sqlx-postgres/src/types/ipnetwork.rs index 539c88ed6..4f619ba99 100644 --- a/sqlx-postgres/src/types/ipnetwork.rs +++ b/sqlx-postgres/src/types/ipnetwork.rs @@ -36,7 +36,7 @@ impl PgHasArrayType for IpNetwork { } impl Encode<'_, Postgres> for IpNetwork { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { // https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/backend/utils/adt/network.c#L293 // https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/backend/utils/adt/network.c#L271 @@ -58,7 +58,7 @@ impl Encode<'_, Postgres> for IpNetwork { } } - IsNull::No + Ok(IsNull::No) } fn size_hint(&self) -> usize { diff --git a/sqlx-postgres/src/types/json.rs b/sqlx-postgres/src/types/json.rs index 60ebe4138..567e48015 100644 --- a/sqlx-postgres/src/types/json.rs +++ b/sqlx-postgres/src/types/json.rs @@ -58,7 +58,7 @@ impl<'q, T> Encode<'q, Postgres> for Json where T: Serialize, { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { // we have a tiny amount of dynamic behavior depending if we are resolved to be JSON // instead of JSONB buf.patch(|buf, ty: &PgTypeInfo| { @@ -71,10 +71,9 @@ where buf.push(1); // the JSON data written to the buffer is the same regardless of parameter type - serde_json::to_writer(&mut **buf, &self.0) - .expect("failed to serialize to JSON for encoding on transmission to the database"); + serde_json::to_writer(&mut **buf, &self.0)?; - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-postgres/src/types/lquery.rs b/sqlx-postgres/src/types/lquery.rs index 79e4d162f..5ca9ff70c 100644 --- a/sqlx-postgres/src/types/lquery.rs +++ b/sqlx-postgres/src/types/lquery.rs @@ -139,12 +139,11 @@ impl Type for PgLQuery { } impl Encode<'_, Postgres> for PgLQuery { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { buf.extend(1i8.to_le_bytes()); - write!(buf, "{self}") - .expect("Display implementation panicked while writing to PgArgumentBuffer"); + write!(buf, "{self}")?; - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-postgres/src/types/ltree.rs b/sqlx-postgres/src/types/ltree.rs index 7ddfb3980..cee2cbe49 100644 --- a/sqlx-postgres/src/types/ltree.rs +++ b/sqlx-postgres/src/types/ltree.rs @@ -181,12 +181,11 @@ impl PgHasArrayType for PgLTree { } impl Encode<'_, Postgres> for PgLTree { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { buf.extend(1i8.to_le_bytes()); - write!(buf, "{self}") - .expect("Display implementation panicked while writing to PgArgumentBuffer"); + write!(buf, "{self}")?; - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-postgres/src/types/mac_address.rs b/sqlx-postgres/src/types/mac_address.rs index 8038159c8..23766e700 100644 --- a/sqlx-postgres/src/types/mac_address.rs +++ b/sqlx-postgres/src/types/mac_address.rs @@ -23,9 +23,9 @@ impl PgHasArrayType for MacAddress { } impl Encode<'_, Postgres> for MacAddress { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { buf.extend_from_slice(&self.bytes()); // write just the address - IsNull::No + Ok(IsNull::No) } fn size_hint(&self) -> usize { diff --git a/sqlx-postgres/src/types/money.rs b/sqlx-postgres/src/types/money.rs index 1ae3e0a97..45f4bdd83 100644 --- a/sqlx-postgres/src/types/money.rs +++ b/sqlx-postgres/src/types/money.rs @@ -165,10 +165,10 @@ where } impl Encode<'_, Postgres> for PgMoney { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { buf.extend(&self.0.to_be_bytes()); - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-postgres/src/types/oid.rs b/sqlx-postgres/src/types/oid.rs index 9841b6034..caa90dfcc 100644 --- a/sqlx-postgres/src/types/oid.rs +++ b/sqlx-postgres/src/types/oid.rs @@ -36,10 +36,10 @@ impl PgHasArrayType for Oid { } impl Encode<'_, Postgres> for Oid { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { buf.extend(&self.0.to_be_bytes()); - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-postgres/src/types/range.rs b/sqlx-postgres/src/types/range.rs index ea3457706..bdcd2ce61 100644 --- a/sqlx-postgres/src/types/range.rs +++ b/sqlx-postgres/src/types/range.rs @@ -292,7 +292,7 @@ impl<'q, T> Encode<'q, Postgres> for PgRange where T: Encode<'q, Postgres>, { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { // https://github.com/postgres/postgres/blob/2f48ede080f42b97b594fb14102c82ca1001b80c/src/backend/utils/adt/rangetypes.c#L245 let mut flags = RangeFlags::empty(); @@ -312,15 +312,15 @@ where buf.push(flags.bits()); if let Bound::Included(v) | Bound::Excluded(v) = &self.start { - buf.encode(v); + buf.encode(v)?; } if let Bound::Included(v) | Bound::Excluded(v) = &self.end { - buf.encode(v); + buf.encode(v)?; } // ranges are themselves never null - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-postgres/src/types/record.rs b/sqlx-postgres/src/types/record.rs index 18dc395bf..a119f29dc 100644 --- a/sqlx-postgres/src/types/record.rs +++ b/sqlx-postgres/src/types/record.rs @@ -34,7 +34,7 @@ impl<'a> PgRecordEncoder<'a> { } #[doc(hidden)] - pub fn encode<'q, T>(&mut self, value: T) -> &mut Self + pub fn encode<'q, T>(&mut self, value: T) -> Result<&mut Self, BoxDynError> where 'a: 'q, T: Encode<'q, Postgres> + Type, @@ -50,10 +50,10 @@ impl<'a> PgRecordEncoder<'a> { self.buf.extend(&ty.0.oid().0.to_be_bytes()); } - self.buf.encode(value); + self.buf.encode(value)?; self.num += 1; - self + Ok(self) } } diff --git a/sqlx-postgres/src/types/rust_decimal.rs b/sqlx-postgres/src/types/rust_decimal.rs index 9f00e3ca3..5d749f16d 100644 --- a/sqlx-postgres/src/types/rust_decimal.rs +++ b/sqlx-postgres/src/types/rust_decimal.rs @@ -72,18 +72,16 @@ impl TryFrom for Decimal { } // This impl is effectively infallible because `NUMERIC` has a greater range than `Decimal`. -impl TryFrom<&'_ Decimal> for PgNumeric { - type Error = BoxDynError; - - fn try_from(decimal: &Decimal) -> Result { +impl From<&'_ Decimal> for PgNumeric { + fn from(decimal: &Decimal) -> Self { // `Decimal` added `is_zero()` as an inherent method in a more recent version if Zero::is_zero(decimal) { - return Ok(PgNumeric::Number { + PgNumeric::Number { sign: PgNumericSign::Positive, scale: 0, weight: 0, digits: vec![], - }); + }; } let scale = decimal.scale() as u16; @@ -131,7 +129,7 @@ impl TryFrom<&'_ Decimal> for PgNumeric { digits.pop(); } - Ok(PgNumeric::Number { + PgNumeric::Number { sign: match decimal.is_sign_negative() { false => PgNumericSign::Positive, true => PgNumericSign::Negative, @@ -139,17 +137,15 @@ impl TryFrom<&'_ Decimal> for PgNumeric { scale: scale as i16, weight, digits, - }) + } } } impl Encode<'_, Postgres> for Decimal { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { - PgNumeric::try_from(self) - .expect("BUG: `Decimal` to `PgNumeric` conversion should be infallible") - .encode(buf); + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + PgNumeric::from(self).encode(buf); - IsNull::No + Ok(IsNull::No) } } diff --git a/sqlx-postgres/src/types/str.rs b/sqlx-postgres/src/types/str.rs index c6938010e..e3240e6a0 100644 --- a/sqlx-postgres/src/types/str.rs +++ b/sqlx-postgres/src/types/str.rs @@ -95,15 +95,15 @@ impl PgHasArrayType for String { } impl Encode<'_, Postgres> for &'_ str { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { buf.extend(self.as_bytes()); - IsNull::No + Ok(IsNull::No) } } impl Encode<'_, Postgres> for Cow<'_, str> { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { match self { Cow::Borrowed(str) => <&str as Encode>::encode(*str, buf), Cow::Owned(str) => <&str as Encode>::encode(&**str, buf), @@ -112,13 +112,13 @@ impl Encode<'_, Postgres> for Cow<'_, str> { } impl Encode<'_, Postgres> for Box { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { <&str as Encode>::encode(&**self, buf) } } impl Encode<'_, Postgres> for String { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { <&str as Encode>::encode(&**self, buf) } } diff --git a/sqlx-postgres/src/types/text.rs b/sqlx-postgres/src/types/text.rs index 7e96d03f2..b5b0a5ed7 100644 --- a/sqlx-postgres/src/types/text.rs +++ b/sqlx-postgres/src/types/text.rs @@ -22,19 +22,9 @@ impl<'q, T> Encode<'q, Postgres> for Text where T: Display, { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { - // Unfortunately, our API design doesn't give us a way to bubble up the error here. - // - // Fortunately, writing to `Vec` is infallible so the only possible source of - // errors is from the implementation of `Display::fmt()` itself, - // where the onus is on the user. - // - // The blanket impl of `ToString` also panics if there's an error, so this is not - // unprecedented. - // - // However, the panic should be documented anyway. - write!(**buf, "{}", self.0).expect("unexpected error from `Display::fmt()`"); - IsNull::No + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + write!(**buf, "{}", self.0)?; + Ok(IsNull::No) } } diff --git a/sqlx-postgres/src/types/time/date.rs b/sqlx-postgres/src/types/time/date.rs index a320afd72..dfd603dbe 100644 --- a/sqlx-postgres/src/types/time/date.rs +++ b/sqlx-postgres/src/types/time/date.rs @@ -21,7 +21,7 @@ impl PgHasArrayType for Date { } impl Encode<'_, Postgres> for Date { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { // DATE is encoded as the days since epoch let days = (*self - PG_EPOCH).whole_days() as i32; Encode::::encode(&days, buf) diff --git a/sqlx-postgres/src/types/time/datetime.rs b/sqlx-postgres/src/types/time/datetime.rs index 90109078d..5dc696048 100644 --- a/sqlx-postgres/src/types/time/datetime.rs +++ b/sqlx-postgres/src/types/time/datetime.rs @@ -35,7 +35,7 @@ impl PgHasArrayType for OffsetDateTime { } impl Encode<'_, Postgres> for PrimitiveDateTime { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { // TIMESTAMP is encoded as the microseconds since the epoch let us = (*self - PG_EPOCH.midnight()).whole_microseconds() as i64; Encode::::encode(&us, buf) @@ -84,7 +84,7 @@ impl<'r> Decode<'r, Postgres> for PrimitiveDateTime { } impl Encode<'_, Postgres> for OffsetDateTime { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { let utc = self.to_offset(offset!(UTC)); let primitive = PrimitiveDateTime::new(utc.date(), utc.time()); diff --git a/sqlx-postgres/src/types/time/time.rs b/sqlx-postgres/src/types/time/time.rs index 7a9b99d9d..9b1b496d4 100644 --- a/sqlx-postgres/src/types/time/time.rs +++ b/sqlx-postgres/src/types/time/time.rs @@ -20,7 +20,7 @@ impl PgHasArrayType for Time { } impl Encode<'_, Postgres> for Time { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { // TIME is encoded as the microseconds since midnight let us = (*self - Time::MIDNIGHT).whole_microseconds() as i64; Encode::::encode(&us, buf) diff --git a/sqlx-postgres/src/types/time_tz.rs b/sqlx-postgres/src/types/time_tz.rs index 5855a0b7d..33ae6944f 100644 --- a/sqlx-postgres/src/types/time_tz.rs +++ b/sqlx-postgres/src/types/time_tz.rs @@ -51,11 +51,12 @@ mod chrono { } impl Encode<'_, Postgres> for PgTimeTz { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { - let _ = >::encode(self.time, buf); - let _ = >::encode(self.offset.utc_minus_local(), buf); + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + let _: IsNull = >::encode(self.time, buf)?; + let _: IsNull = + >::encode(self.offset.utc_minus_local(), buf)?; - IsNull::No + Ok(IsNull::No) } fn size_hint(&self) -> usize { @@ -134,11 +135,12 @@ mod time { } impl Encode<'_, Postgres> for PgTimeTz { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { - let _ =