feat(mysql): disable parameter type check on older MySQL, add support for NULL

This commit is contained in:
Ryan Leckey 2021-04-16 00:39:21 -07:00
parent d03a294555
commit 8dcaa039c8
15 changed files with 149 additions and 57 deletions

View File

@ -173,8 +173,8 @@ impl<'a, Db: Database> Argument<'a, Db> {
&self,
ty: &Db::TypeInfo,
out: &mut <Db as HasOutput<'x>>::Output,
) -> Result<()> {
let res = if !self.unchecked && !(self.type_compatible)(ty) {
) -> Result<encode::IsNull> {
let res = if !self.unchecked && !ty.is_unknown() && !(self.type_compatible)(ty) {
Err(encode::Error::TypeNotCompatible {
rust_type_name: self.rust_type_name,
sql_type_name: ty.name(),

View File

@ -4,15 +4,31 @@ use std::fmt::{self, Display, Formatter};
use crate::database::HasOutput;
use crate::Database;
/// Type returned from [`Encode::encode`] that indicates if the value encoded is the SQL `null` or not.
pub enum IsNull {
/// The value is the SQL `null`.
///
/// No data was written to the output buffer.
///
Yes,
/// The value is not the SQL `null`.
///
/// This does not mean that any data was written to the output buffer. For example,
/// an empty string has no data, but is not null.
///
No,
}
/// A type that can be encoded into a SQL value.
pub trait Encode<Db: Database>: Send + Sync {
/// Encode this value into the specified SQL type.
fn encode(&self, ty: &Db::TypeInfo, out: &mut <Db as HasOutput<'_>>::Output) -> Result<()>;
fn encode(&self, ty: &Db::TypeInfo, out: &mut <Db as HasOutput<'_>>::Output) -> Result;
}
impl<T: Encode<Db>, Db: Database> Encode<Db> for &T {
#[inline]
fn encode(&self, ty: &Db::TypeInfo, out: &mut <Db as HasOutput<'_>>::Output) -> Result<()> {
fn encode(&self, ty: &Db::TypeInfo, out: &mut <Db as HasOutput<'_>>::Output) -> Result {
(*self).encode(ty, out)
}
}
@ -63,4 +79,4 @@ impl<E: StdError + Send + Sync + 'static> From<E> for Error {
}
/// A specialized result type representing the result of encoding a SQL value.
pub type Result<T> = std::result::Result<T, Error>;
pub type Result = std::result::Result<IsNull, Error>;

View File

@ -34,6 +34,7 @@ mod execute;
mod executor;
mod from_row;
mod isolation_level;
mod null;
mod options;
mod query_result;
mod raw_value;
@ -72,6 +73,7 @@ pub use execute::Execute;
pub use executor::Executor;
pub use from_row::FromRow;
pub use isolation_level::IsolationLevel;
pub use null::Null;
pub use options::ConnectOptions;
pub use query_result::QueryResult;
pub use r#type::{Type, TypeDecode, TypeDecodeOwned, TypeEncode};

50
sqlx-core/src/null.rs Normal file
View File

@ -0,0 +1,50 @@
use crate::database::{HasOutput, HasRawValue};
use crate::{decode, encode, Database, Decode, Encode, RawValue, Type};
use std::ops::Not;
#[derive(Debug)]
pub struct Null;
impl<Db: Database, T: Type<Db>> Type<Db> for Option<T>
where
Null: Type<Db>,
{
fn type_id() -> <Db as Database>::TypeId
where
Self: Sized,
{
T::type_id()
}
fn compatible(ty: &<Db as Database>::TypeInfo) -> bool
where
Self: Sized,
{
T::compatible(ty)
}
}
impl<Db: Database, T: Encode<Db>> Encode<Db> for Option<T>
where
Null: Encode<Db>,
{
fn encode(
&self,
ty: &<Db as Database>::TypeInfo,
out: &mut <Db as HasOutput<'_>>::Output,
) -> encode::Result {
match self {
Some(value) => value.encode(ty, out),
None => Null.encode(ty, out),
}
}
}
impl<'r, Db: Database, T: Decode<'r, Db>> Decode<'r, Db> for Option<T>
where
Null: Decode<'r, Db>,
{
fn decode(value: <Db as HasRawValue<'r>>::RawValue) -> decode::Result<Self> {
value.is_null().not().then(|| T::decode(value)).transpose()
}
}

View File

@ -92,7 +92,7 @@ pub trait Row: 'static + Send + Sync {
{
let value = self.try_get_raw(&index)?;
let res = if !T::compatible(value.type_info()) {
let res = if !value.is_null() && !T::compatible(value.type_info()) {
Err(decode::Error::TypeNotCompatible {
rust_type_name: any::type_name::<T>(),
sql_type_name: value.type_info().name(),

View File

@ -40,7 +40,7 @@ impl<'de> Deserialize<'de, (MySqlRawValueFormat, &'de [MySqlColumn])> for Row {
// [0x00] packer header
let header = buf.get_u8();
assert!(header == 0x00);
assert_eq!(header, 0x00);
// NULL bit map
let null = buf.split_to((columns.len() + 9) / 8);
@ -49,7 +49,7 @@ impl<'de> Deserialize<'de, (MySqlRawValueFormat, &'de [MySqlColumn])> for Row {
// NULL columns are marked in the bitmap and are not in this list
for (i, col) in columns.iter().enumerate() {
// NOTE: the column index starts at the 3rd bit
let null_i = i + 3;
let null_i = i + 2;
let is_null = null[null_i / 8] & (1 << (null_i % 8) as u8) != 0;
if is_null {

View File

@ -13,7 +13,17 @@ pub struct MySqlTypeId(u8, u8);
const UNSIGNED: u8 = 0x80;
impl MySqlTypeId {
pub(crate) const fn new(def: &ColumnDefinition) -> Self {
pub(crate) fn new(def: &ColumnDefinition) -> Self {
if def.schema.is_empty()
&& def.ty == Self::VARCHAR.0
&& def.flags.contains(ColumnFlags::BINARY_COLLATION)
{
// older MySQL typed every parameter as VARBINARY
// this will pick it up and emit a NULL type so we don't
// try and do strong type checking on parameters
return Self::NULL;
}
Self(def.ty, if def.flags.contains(ColumnFlags::UNSIGNED) { UNSIGNED } else { 0 })
}

View File

@ -20,7 +20,7 @@ pub struct MySqlTypeInfo {
}
impl MySqlTypeInfo {
pub(crate) const fn new(def: &ColumnDefinition) -> Self {
pub(crate) fn new(def: &ColumnDefinition) -> Self {
Self {
id: MySqlTypeId::new(def),
charset: def.charset,

View File

@ -123,6 +123,7 @@
mod bool;
mod bytes;
mod int;
mod null;
mod str;
mod uint;

View File

@ -16,7 +16,7 @@ impl Type<MySql> for bool {
}
impl Encode<MySql> for bool {
fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result<()> {
fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result {
<i128 as Encode<MySql>>::encode(&(*self as i128), ty, out)
}
}

View File

@ -25,10 +25,10 @@ impl Type<MySql> for &'_ [u8] {
}
impl Encode<MySql> for &'_ [u8] {
fn encode(&self, _: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result<()> {
fn encode(&self, _: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result {
out.buffer().write_bytes_lenenc(self);
Ok(())
Ok(encode::IsNull::No)
}
}
@ -49,7 +49,7 @@ impl Type<MySql> for Vec<u8> {
}
impl Encode<MySql> for Vec<u8> {
fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result<()> {
fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result {
<&[u8] as Encode<MySql>>::encode(&self.as_slice(), ty, out)
}
}
@ -71,7 +71,7 @@ impl Type<MySql> for Bytes {
}
impl Encode<MySql> for Bytes {
fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result<()> {
fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result {
<&[u8] as Encode<MySql>>::encode(&&**self, ty, out)
}
}

View File

@ -6,39 +6,25 @@ use crate::{MySql, MySqlOutput, MySqlRawValue, MySqlTypeId};
// check that the incoming value is not too large or too small
// to fit into the target SQL type
fn ensure_not_too_large_or_too_small(value: i128, ty: &MySqlTypeInfo) -> encode::Result<()> {
let max: i128 = match ty.id() {
MySqlTypeId::TINYINT => i8::MAX as _,
MySqlTypeId::SMALLINT => i16::MAX as _,
MySqlTypeId::MEDIUMINT => 0x7F_FF_FF as _,
MySqlTypeId::INT => i32::MAX as _,
MySqlTypeId::BIGINT => i64::MAX as _,
fn ensure_not_too_large_or_too_small(value: i128, ty: &MySqlTypeInfo) -> Result<(), encode::Error> {
let (max, min): (i128, i128) = match ty.id() {
MySqlTypeId::TINYINT => (i8::MAX as _, i8::MIN as _),
MySqlTypeId::SMALLINT => (i16::MAX as _, i16::MIN as _),
MySqlTypeId::MEDIUMINT => (0x7F_FF_FF as _, 0x80_00_00 as _),
MySqlTypeId::INT => (i32::MAX as _, i32::MIN as _),
MySqlTypeId::BIGINT => (i64::MAX as _, i64::MIN as _),
MySqlTypeId::TINYINT_UNSIGNED => u8::MAX as _,
MySqlTypeId::SMALLINT_UNSIGNED => u16::MAX as _,
MySqlTypeId::MEDIUMINT_UNSIGNED => 0xFF_FF_FF as _,
MySqlTypeId::INT_UNSIGNED => u32::MAX as _,
MySqlTypeId::BIGINT_UNSIGNED => u64::MAX as _,
MySqlTypeId::TINYINT_UNSIGNED => (u8::MAX as _, u8::MIN as _),
MySqlTypeId::SMALLINT_UNSIGNED => (u16::MAX as _, u16::MIN as _),
MySqlTypeId::MEDIUMINT_UNSIGNED => (0xFF_FF_FF as _, 0 as _),
MySqlTypeId::INT_UNSIGNED => (u32::MAX as _, u32::MIN as _),
MySqlTypeId::BIGINT_UNSIGNED => (u64::MAX as _, u64::MIN as _),
// not an integer type
_ => unreachable!(),
};
let min: i128 = match ty.id() {
MySqlTypeId::TINYINT => i8::MIN as _,
MySqlTypeId::SMALLINT => i16::MIN as _,
MySqlTypeId::MEDIUMINT => 0x80_00_00 as _,
MySqlTypeId::INT => i32::MIN as _,
MySqlTypeId::BIGINT => i64::MIN as _,
MySqlTypeId::TINYINT_UNSIGNED => u8::MIN as _,
MySqlTypeId::SMALLINT_UNSIGNED => u16::MIN as _,
MySqlTypeId::MEDIUMINT_UNSIGNED => 0 as _,
MySqlTypeId::INT_UNSIGNED => u32::MIN as _,
MySqlTypeId::BIGINT_UNSIGNED => u64::MIN as _,
// not an integer type
_ => unreachable!(),
// not an integer type, if we got this far its because this is _unchecked
// just let it through
_ => {
return Ok(());
}
};
if value > max {
@ -73,12 +59,12 @@ macro_rules! impl_type_int {
}
impl Encode<MySql> for $ty {
fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result<()> {
fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result {
ensure_not_too_large_or_too_small((*self $(as $real)?).into(), ty)?;
out.buffer().extend_from_slice(&self.to_le_bytes());
Ok(())
Ok(encode::IsNull::No)
}
}

View File

@ -0,0 +1,24 @@
use crate::{MySql, MySqlOutput, MySqlRawValue, MySqlTypeId, MySqlTypeInfo};
use sqlx_core::database::{HasOutput, HasRawValue};
use sqlx_core::{decode, encode, Database, Decode, Encode, Null, Type};
impl Type<MySql> for Null {
fn type_id() -> MySqlTypeId
where
Self: Sized,
{
MySqlTypeId::NULL
}
}
impl Encode<MySql> for Null {
fn encode(&self, _: &MySqlTypeInfo, _: &mut MySqlOutput<'_>) -> encode::Result {
Ok(encode::IsNull::Yes)
}
}
impl<'r> Decode<'r, MySql> for Null {
fn decode(_: MySqlRawValue<'r>) -> decode::Result<Self> {
Ok(Self)
}
}

View File

@ -16,10 +16,10 @@ impl Type<MySql> for &'_ str {
}
impl Encode<MySql> for &'_ str {
fn encode(&self, _: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result<()> {
fn encode(&self, _: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result {
out.buffer().write_bytes_lenenc(self.as_bytes());
Ok(())
Ok(encode::IsNull::No)
}
}
@ -40,7 +40,7 @@ impl Type<MySql> for String {
}
impl Encode<MySql> for String {
fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result<()> {
fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result {
<&str as Encode<MySql>>::encode(&self.as_str(), ty, out)
}
}
@ -62,7 +62,7 @@ impl Type<MySql> for ByteString {
}
impl Encode<MySql> for ByteString {
fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result<()> {
fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result {
<&str as Encode<MySql>>::encode(&&**self, ty, out)
}
}

View File

@ -33,7 +33,7 @@ where
// check that the incoming value is not too large
// to fit into the target SQL type
fn ensure_not_too_large(value: u128, ty: &MySqlTypeInfo) -> encode::Result<()> {
fn ensure_not_too_large(value: u128, ty: &MySqlTypeInfo) -> Result<(), encode::Error> {
let max = match ty.id() {
MySqlTypeId::TINYINT => i8::MAX as _,
MySqlTypeId::SMALLINT => i16::MAX as _,
@ -47,8 +47,11 @@ fn ensure_not_too_large(value: u128, ty: &MySqlTypeInfo) -> encode::Result<()> {
MySqlTypeId::INT_UNSIGNED => u32::MAX as _,
MySqlTypeId::BIGINT_UNSIGNED => u64::MAX as _,
// not an integer type
_ => unreachable!(),
// not an integer type, if we got this far its because this is _unchecked
// just let it through
_ => {
return Ok(());
}
};
if value > max {
@ -75,12 +78,12 @@ macro_rules! impl_type_uint {
}
impl Encode<MySql> for $ty {
fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result<()> {
fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result {
ensure_not_too_large((*self $(as $real)?).into(), ty)?;
out.buffer().extend_from_slice(&self.to_le_bytes());
Ok(())
Ok(encode::IsNull::No)
}
}