From 8dcaa039c80351785f2e79b81c81c8c8f705fc2c Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Fri, 16 Apr 2021 00:39:21 -0700 Subject: [PATCH] feat(mysql): disable parameter type check on older MySQL, add support for NULL --- sqlx-core/src/arguments.rs | 4 +-- sqlx-core/src/encode.rs | 22 ++++++++++++-- sqlx-core/src/lib.rs | 2 ++ sqlx-core/src/null.rs | 50 ++++++++++++++++++++++++++++++++ sqlx-core/src/row.rs | 2 +- sqlx-mysql/src/protocol/row.rs | 4 +-- sqlx-mysql/src/type_id.rs | 12 +++++++- sqlx-mysql/src/type_info.rs | 2 +- sqlx-mysql/src/types.rs | 1 + sqlx-mysql/src/types/bool.rs | 2 +- sqlx-mysql/src/types/bytes.rs | 8 +++--- sqlx-mysql/src/types/int.rs | 52 +++++++++++++--------------------- sqlx-mysql/src/types/null.rs | 24 ++++++++++++++++ sqlx-mysql/src/types/str.rs | 8 +++--- sqlx-mysql/src/types/uint.rs | 13 +++++---- 15 files changed, 149 insertions(+), 57 deletions(-) create mode 100644 sqlx-core/src/null.rs create mode 100644 sqlx-mysql/src/types/null.rs diff --git a/sqlx-core/src/arguments.rs b/sqlx-core/src/arguments.rs index 499c747c..0a3d69a9 100644 --- a/sqlx-core/src/arguments.rs +++ b/sqlx-core/src/arguments.rs @@ -173,8 +173,8 @@ impl<'a, Db: Database> Argument<'a, Db> { &self, ty: &Db::TypeInfo, out: &mut >::Output, - ) -> Result<()> { - let res = if !self.unchecked && !(self.type_compatible)(ty) { + ) -> Result { + let res = if !self.unchecked && !ty.is_unknown() && !(self.type_compatible)(ty) { Err(encode::Error::TypeNotCompatible { rust_type_name: self.rust_type_name, sql_type_name: ty.name(), diff --git a/sqlx-core/src/encode.rs b/sqlx-core/src/encode.rs index bd4c9381..ab7cd968 100644 --- a/sqlx-core/src/encode.rs +++ b/sqlx-core/src/encode.rs @@ -4,15 +4,31 @@ use std::fmt::{self, Display, Formatter}; use crate::database::HasOutput; use crate::Database; +/// Type returned from [`Encode::encode`] that indicates if the value encoded is the SQL `null` or not. +pub enum IsNull { + /// The value is the SQL `null`. + /// + /// No data was written to the output buffer. + /// + Yes, + + /// The value is not the SQL `null`. + /// + /// This does not mean that any data was written to the output buffer. For example, + /// an empty string has no data, but is not null. + /// + No, +} + /// A type that can be encoded into a SQL value. pub trait Encode: Send + Sync { /// Encode this value into the specified SQL type. - fn encode(&self, ty: &Db::TypeInfo, out: &mut >::Output) -> Result<()>; + fn encode(&self, ty: &Db::TypeInfo, out: &mut >::Output) -> Result; } impl, Db: Database> Encode for &T { #[inline] - fn encode(&self, ty: &Db::TypeInfo, out: &mut >::Output) -> Result<()> { + fn encode(&self, ty: &Db::TypeInfo, out: &mut >::Output) -> Result { (*self).encode(ty, out) } } @@ -63,4 +79,4 @@ impl From for Error { } /// A specialized result type representing the result of encoding a SQL value. -pub type Result = std::result::Result; +pub type Result = std::result::Result; diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index ff8b5237..84bae326 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -34,6 +34,7 @@ mod execute; mod executor; mod from_row; mod isolation_level; +mod null; mod options; mod query_result; mod raw_value; @@ -72,6 +73,7 @@ pub use execute::Execute; pub use executor::Executor; pub use from_row::FromRow; pub use isolation_level::IsolationLevel; +pub use null::Null; pub use options::ConnectOptions; pub use query_result::QueryResult; pub use r#type::{Type, TypeDecode, TypeDecodeOwned, TypeEncode}; diff --git a/sqlx-core/src/null.rs b/sqlx-core/src/null.rs new file mode 100644 index 00000000..59614609 --- /dev/null +++ b/sqlx-core/src/null.rs @@ -0,0 +1,50 @@ +use crate::database::{HasOutput, HasRawValue}; +use crate::{decode, encode, Database, Decode, Encode, RawValue, Type}; +use std::ops::Not; + +#[derive(Debug)] +pub struct Null; + +impl> Type for Option +where + Null: Type, +{ + fn type_id() -> ::TypeId + where + Self: Sized, + { + T::type_id() + } + + fn compatible(ty: &::TypeInfo) -> bool + where + Self: Sized, + { + T::compatible(ty) + } +} + +impl> Encode for Option +where + Null: Encode, +{ + fn encode( + &self, + ty: &::TypeInfo, + out: &mut >::Output, + ) -> encode::Result { + match self { + Some(value) => value.encode(ty, out), + None => Null.encode(ty, out), + } + } +} + +impl<'r, Db: Database, T: Decode<'r, Db>> Decode<'r, Db> for Option +where + Null: Decode<'r, Db>, +{ + fn decode(value: >::RawValue) -> decode::Result { + value.is_null().not().then(|| T::decode(value)).transpose() + } +} diff --git a/sqlx-core/src/row.rs b/sqlx-core/src/row.rs index 3a1b7df4..4bfccb8b 100644 --- a/sqlx-core/src/row.rs +++ b/sqlx-core/src/row.rs @@ -92,7 +92,7 @@ pub trait Row: 'static + Send + Sync { { let value = self.try_get_raw(&index)?; - let res = if !T::compatible(value.type_info()) { + let res = if !value.is_null() && !T::compatible(value.type_info()) { Err(decode::Error::TypeNotCompatible { rust_type_name: any::type_name::(), sql_type_name: value.type_info().name(), diff --git a/sqlx-mysql/src/protocol/row.rs b/sqlx-mysql/src/protocol/row.rs index 039b6609..5bf7b075 100644 --- a/sqlx-mysql/src/protocol/row.rs +++ b/sqlx-mysql/src/protocol/row.rs @@ -40,7 +40,7 @@ impl<'de> Deserialize<'de, (MySqlRawValueFormat, &'de [MySqlColumn])> for Row { // [0x00] packer header let header = buf.get_u8(); - assert!(header == 0x00); + assert_eq!(header, 0x00); // NULL bit map let null = buf.split_to((columns.len() + 9) / 8); @@ -49,7 +49,7 @@ impl<'de> Deserialize<'de, (MySqlRawValueFormat, &'de [MySqlColumn])> for Row { // NULL columns are marked in the bitmap and are not in this list for (i, col) in columns.iter().enumerate() { // NOTE: the column index starts at the 3rd bit - let null_i = i + 3; + let null_i = i + 2; let is_null = null[null_i / 8] & (1 << (null_i % 8) as u8) != 0; if is_null { diff --git a/sqlx-mysql/src/type_id.rs b/sqlx-mysql/src/type_id.rs index 0e83701a..e8d2da8a 100644 --- a/sqlx-mysql/src/type_id.rs +++ b/sqlx-mysql/src/type_id.rs @@ -13,7 +13,17 @@ pub struct MySqlTypeId(u8, u8); const UNSIGNED: u8 = 0x80; impl MySqlTypeId { - pub(crate) const fn new(def: &ColumnDefinition) -> Self { + pub(crate) fn new(def: &ColumnDefinition) -> Self { + if def.schema.is_empty() + && def.ty == Self::VARCHAR.0 + && def.flags.contains(ColumnFlags::BINARY_COLLATION) + { + // older MySQL typed every parameter as VARBINARY + // this will pick it up and emit a NULL type so we don't + // try and do strong type checking on parameters + return Self::NULL; + } + Self(def.ty, if def.flags.contains(ColumnFlags::UNSIGNED) { UNSIGNED } else { 0 }) } diff --git a/sqlx-mysql/src/type_info.rs b/sqlx-mysql/src/type_info.rs index 5aff13bc..6c07be93 100644 --- a/sqlx-mysql/src/type_info.rs +++ b/sqlx-mysql/src/type_info.rs @@ -20,7 +20,7 @@ pub struct MySqlTypeInfo { } impl MySqlTypeInfo { - pub(crate) const fn new(def: &ColumnDefinition) -> Self { + pub(crate) fn new(def: &ColumnDefinition) -> Self { Self { id: MySqlTypeId::new(def), charset: def.charset, diff --git a/sqlx-mysql/src/types.rs b/sqlx-mysql/src/types.rs index 026ee4e5..58b4a074 100644 --- a/sqlx-mysql/src/types.rs +++ b/sqlx-mysql/src/types.rs @@ -123,6 +123,7 @@ mod bool; mod bytes; mod int; +mod null; mod str; mod uint; diff --git a/sqlx-mysql/src/types/bool.rs b/sqlx-mysql/src/types/bool.rs index 78eb3d3f..5955de53 100644 --- a/sqlx-mysql/src/types/bool.rs +++ b/sqlx-mysql/src/types/bool.rs @@ -16,7 +16,7 @@ impl Type for bool { } impl Encode for bool { - fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result<()> { + fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result { >::encode(&(*self as i128), ty, out) } } diff --git a/sqlx-mysql/src/types/bytes.rs b/sqlx-mysql/src/types/bytes.rs index 89a364c4..32fd0008 100644 --- a/sqlx-mysql/src/types/bytes.rs +++ b/sqlx-mysql/src/types/bytes.rs @@ -25,10 +25,10 @@ impl Type for &'_ [u8] { } impl Encode for &'_ [u8] { - fn encode(&self, _: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result<()> { + fn encode(&self, _: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result { out.buffer().write_bytes_lenenc(self); - Ok(()) + Ok(encode::IsNull::No) } } @@ -49,7 +49,7 @@ impl Type for Vec { } impl Encode for Vec { - fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result<()> { + fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result { <&[u8] as Encode>::encode(&self.as_slice(), ty, out) } } @@ -71,7 +71,7 @@ impl Type for Bytes { } impl Encode for Bytes { - fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result<()> { + fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result { <&[u8] as Encode>::encode(&&**self, ty, out) } } diff --git a/sqlx-mysql/src/types/int.rs b/sqlx-mysql/src/types/int.rs index 0bb03e09..6dda0fbc 100644 --- a/sqlx-mysql/src/types/int.rs +++ b/sqlx-mysql/src/types/int.rs @@ -6,39 +6,25 @@ use crate::{MySql, MySqlOutput, MySqlRawValue, MySqlTypeId}; // check that the incoming value is not too large or too small // to fit into the target SQL type -fn ensure_not_too_large_or_too_small(value: i128, ty: &MySqlTypeInfo) -> encode::Result<()> { - let max: i128 = match ty.id() { - MySqlTypeId::TINYINT => i8::MAX as _, - MySqlTypeId::SMALLINT => i16::MAX as _, - MySqlTypeId::MEDIUMINT => 0x7F_FF_FF as _, - MySqlTypeId::INT => i32::MAX as _, - MySqlTypeId::BIGINT => i64::MAX as _, +fn ensure_not_too_large_or_too_small(value: i128, ty: &MySqlTypeInfo) -> Result<(), encode::Error> { + let (max, min): (i128, i128) = match ty.id() { + MySqlTypeId::TINYINT => (i8::MAX as _, i8::MIN as _), + MySqlTypeId::SMALLINT => (i16::MAX as _, i16::MIN as _), + MySqlTypeId::MEDIUMINT => (0x7F_FF_FF as _, 0x80_00_00 as _), + MySqlTypeId::INT => (i32::MAX as _, i32::MIN as _), + MySqlTypeId::BIGINT => (i64::MAX as _, i64::MIN as _), - MySqlTypeId::TINYINT_UNSIGNED => u8::MAX as _, - MySqlTypeId::SMALLINT_UNSIGNED => u16::MAX as _, - MySqlTypeId::MEDIUMINT_UNSIGNED => 0xFF_FF_FF as _, - MySqlTypeId::INT_UNSIGNED => u32::MAX as _, - MySqlTypeId::BIGINT_UNSIGNED => u64::MAX as _, + MySqlTypeId::TINYINT_UNSIGNED => (u8::MAX as _, u8::MIN as _), + MySqlTypeId::SMALLINT_UNSIGNED => (u16::MAX as _, u16::MIN as _), + MySqlTypeId::MEDIUMINT_UNSIGNED => (0xFF_FF_FF as _, 0 as _), + MySqlTypeId::INT_UNSIGNED => (u32::MAX as _, u32::MIN as _), + MySqlTypeId::BIGINT_UNSIGNED => (u64::MAX as _, u64::MIN as _), - // not an integer type - _ => unreachable!(), - }; - - let min: i128 = match ty.id() { - MySqlTypeId::TINYINT => i8::MIN as _, - MySqlTypeId::SMALLINT => i16::MIN as _, - MySqlTypeId::MEDIUMINT => 0x80_00_00 as _, - MySqlTypeId::INT => i32::MIN as _, - MySqlTypeId::BIGINT => i64::MIN as _, - - MySqlTypeId::TINYINT_UNSIGNED => u8::MIN as _, - MySqlTypeId::SMALLINT_UNSIGNED => u16::MIN as _, - MySqlTypeId::MEDIUMINT_UNSIGNED => 0 as _, - MySqlTypeId::INT_UNSIGNED => u32::MIN as _, - MySqlTypeId::BIGINT_UNSIGNED => u64::MIN as _, - - // not an integer type - _ => unreachable!(), + // not an integer type, if we got this far its because this is _unchecked + // just let it through + _ => { + return Ok(()); + } }; if value > max { @@ -73,12 +59,12 @@ macro_rules! impl_type_int { } impl Encode for $ty { - fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result<()> { + fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result { ensure_not_too_large_or_too_small((*self $(as $real)?).into(), ty)?; out.buffer().extend_from_slice(&self.to_le_bytes()); - Ok(()) + Ok(encode::IsNull::No) } } diff --git a/sqlx-mysql/src/types/null.rs b/sqlx-mysql/src/types/null.rs new file mode 100644 index 00000000..88281eb6 --- /dev/null +++ b/sqlx-mysql/src/types/null.rs @@ -0,0 +1,24 @@ +use crate::{MySql, MySqlOutput, MySqlRawValue, MySqlTypeId, MySqlTypeInfo}; +use sqlx_core::database::{HasOutput, HasRawValue}; +use sqlx_core::{decode, encode, Database, Decode, Encode, Null, Type}; + +impl Type for Null { + fn type_id() -> MySqlTypeId + where + Self: Sized, + { + MySqlTypeId::NULL + } +} + +impl Encode for Null { + fn encode(&self, _: &MySqlTypeInfo, _: &mut MySqlOutput<'_>) -> encode::Result { + Ok(encode::IsNull::Yes) + } +} + +impl<'r> Decode<'r, MySql> for Null { + fn decode(_: MySqlRawValue<'r>) -> decode::Result { + Ok(Self) + } +} diff --git a/sqlx-mysql/src/types/str.rs b/sqlx-mysql/src/types/str.rs index a7221c77..68029b86 100644 --- a/sqlx-mysql/src/types/str.rs +++ b/sqlx-mysql/src/types/str.rs @@ -16,10 +16,10 @@ impl Type for &'_ str { } impl Encode for &'_ str { - fn encode(&self, _: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result<()> { + fn encode(&self, _: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result { out.buffer().write_bytes_lenenc(self.as_bytes()); - Ok(()) + Ok(encode::IsNull::No) } } @@ -40,7 +40,7 @@ impl Type for String { } impl Encode for String { - fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result<()> { + fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result { <&str as Encode>::encode(&self.as_str(), ty, out) } } @@ -62,7 +62,7 @@ impl Type for ByteString { } impl Encode for ByteString { - fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result<()> { + fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result { <&str as Encode>::encode(&&**self, ty, out) } } diff --git a/sqlx-mysql/src/types/uint.rs b/sqlx-mysql/src/types/uint.rs index 143ff332..3ff8d7f1 100644 --- a/sqlx-mysql/src/types/uint.rs +++ b/sqlx-mysql/src/types/uint.rs @@ -33,7 +33,7 @@ where // check that the incoming value is not too large // to fit into the target SQL type -fn ensure_not_too_large(value: u128, ty: &MySqlTypeInfo) -> encode::Result<()> { +fn ensure_not_too_large(value: u128, ty: &MySqlTypeInfo) -> Result<(), encode::Error> { let max = match ty.id() { MySqlTypeId::TINYINT => i8::MAX as _, MySqlTypeId::SMALLINT => i16::MAX as _, @@ -47,8 +47,11 @@ fn ensure_not_too_large(value: u128, ty: &MySqlTypeInfo) -> encode::Result<()> { MySqlTypeId::INT_UNSIGNED => u32::MAX as _, MySqlTypeId::BIGINT_UNSIGNED => u64::MAX as _, - // not an integer type - _ => unreachable!(), + // not an integer type, if we got this far its because this is _unchecked + // just let it through + _ => { + return Ok(()); + } }; if value > max { @@ -75,12 +78,12 @@ macro_rules! impl_type_uint { } impl Encode for $ty { - fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result<()> { + fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result { ensure_not_too_large((*self $(as $real)?).into(), ty)?; out.buffer().extend_from_slice(&self.to_le_bytes()); - Ok(()) + Ok(encode::IsNull::No) } }