From c9252570a954d4853a27cb3524d24b2c21ab7ebd Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Fri, 26 Feb 2021 00:20:38 -0800 Subject: [PATCH] feat(mysql): impl Type for u8, u16, u32, and u64 --- sqlx-mysql/src/types/uint.rs | 97 ++++++++++++++++++++++-------------- 1 file changed, 59 insertions(+), 38 deletions(-) diff --git a/sqlx-mysql/src/types/uint.rs b/sqlx-mysql/src/types/uint.rs index 26fa224e..dea21d3b 100644 --- a/sqlx-mysql/src/types/uint.rs +++ b/sqlx-mysql/src/types/uint.rs @@ -10,16 +10,12 @@ use sqlx_core::{Decode, Encode, Type}; use crate::type_info::MySqlTypeInfo; use crate::{MySql, MySqlOutput, MySqlRawValue, MySqlRawValueFormat, MySqlTypeId}; -// https://dev.mysql.com/doc/internals/en/binary-protocol-value.html#packet-ProtocolBinary - -const NUMBER_TOO_LARGE: &str = "number too large to fit in target type"; - -// shared among all Decode impls for unsigned and signed integers -fn decode_int_or_uint(value: &MySqlRawValue<'_>) -> decode::Result +pub(super) fn decode_int_or_uint(value: &MySqlRawValue<'_>) -> decode::Result where - T: TryFrom + TryFrom + FromStr, - >::Error: 'static + StdError + Send + Sync, + T: TryFrom + FromStr, + T: TryFrom + FromStr, >::Error: 'static + StdError + Send + Sync, + >::Error: 'static + StdError + Send + Sync, ::Err: 'static + StdError + Send + Sync, { if value.format() == MySqlRawValueFormat::Text { @@ -27,51 +23,76 @@ where } let mut bytes = value.as_bytes()?; - - // start from u64 if the value is marked as unsigned - // otherwise start from i64 let is_unsigned = value.type_info().id().is_unsigned(); - - // pull at most 8 bytes from the buffer - let len = cmp::max(bytes.len(), 8); + let size = cmp::max(bytes.len(), 8); Ok(if is_unsigned { - bytes.get_uint_le(len).try_into()? + bytes.get_uint_le(size).try_into()? } else { - bytes.get_int_le(len).try_into()? + bytes.get_int_le(size).try_into()? }) } -impl Type for u8 { - fn type_id() -> MySqlTypeId { - MySqlTypeId::TINYINT_UNSIGNED +// check that the incoming value is not too large +// to fit into the target SQL type +fn ensure_not_too_large(value: u64, ty: &MySqlTypeInfo) -> encode::Result<()> { + let max = match ty.id() { + MySqlTypeId::TINYINT => i8::MAX as _, + MySqlTypeId::SMALLINT => i16::MAX as _, + MySqlTypeId::MEDIUMINT => 0x7F_FF_FF as _, + MySqlTypeId::INT => i32::MAX as _, + MySqlTypeId::BIGINT => i64::MAX as _, + + MySqlTypeId::TINYINT_UNSIGNED => u8::MAX as _, + MySqlTypeId::SMALLINT_UNSIGNED => u16::MAX as _, + MySqlTypeId::MEDIUMINT_UNSIGNED => 0xFF_FF_FF as _, + MySqlTypeId::INT_UNSIGNED => u32::MAX as _, + MySqlTypeId::BIGINT_UNSIGNED => u64::MAX, + + // not an integer type + _ => unreachable!(), + }; + + if value > max { + return Err(encode::Error::msg(format!( + "number `{}` too large to fit in SQL type `{}`", + value, + ty.name() + ))); } - fn compatible(ty: &MySqlTypeInfo) -> bool { - ty.id().is_integer() - } + Ok(()) } -impl Encode for u8 { - fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result<()> { - match ty.id() { - MySqlTypeId::TINYINT_UNSIGNED => {} - - MySqlTypeId::TINYINT if *self > 0x7f => { - return Err(encode::Error::msg(NUMBER_TOO_LARGE)); +macro_rules! impl_type_uint { + ($ty:ty => $sql:ident) => { + impl Type for $ty { + fn type_id() -> MySqlTypeId { + MySqlTypeId::$sql } - _ => {} + fn compatible(ty: &MySqlTypeInfo) -> bool { + ty.id().is_integer() + } } - out.buffer().push(*self); + impl Encode for $ty { + fn encode(&self, _ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result<()> { + out.buffer().extend_from_slice(&self.to_le_bytes()); - Ok(()) - } + Ok(()) + } + } + + impl<'r> Decode<'r, MySql> for $ty { + fn decode(value: MySqlRawValue<'r>) -> decode::Result { + decode_int_or_uint(&value) + } + } + }; } -impl<'r> Decode<'r, MySql> for u8 { - fn decode(value: MySqlRawValue<'r>) -> decode::Result { - decode_int_or_uint(&value) - } -} +impl_type_uint! { u8 => TINYINT_UNSIGNED } +impl_type_uint! { u16 => SMALLINT_UNSIGNED } +impl_type_uint! { u32 => INT_UNSIGNED } +impl_type_uint! { u64 => BIGINT_UNSIGNED }