feat(mysql): impl Type for u8, u16, u32, and u64

This commit is contained in:
Ryan Leckey 2021-02-26 00:20:38 -08:00
parent 730439fcf7
commit c9252570a9
No known key found for this signature in database
GPG Key ID: F8AA68C235AB08C9

View File

@ -10,16 +10,12 @@ use sqlx_core::{Decode, Encode, Type};
use crate::type_info::MySqlTypeInfo;
use crate::{MySql, MySqlOutput, MySqlRawValue, MySqlRawValueFormat, MySqlTypeId};
// https://dev.mysql.com/doc/internals/en/binary-protocol-value.html#packet-ProtocolBinary
const NUMBER_TOO_LARGE: &str = "number too large to fit in target type";
// shared among all Decode impls for unsigned and signed integers
fn decode_int_or_uint<T>(value: &MySqlRawValue<'_>) -> decode::Result<T>
pub(super) fn decode_int_or_uint<T>(value: &MySqlRawValue<'_>) -> decode::Result<T>
where
T: TryFrom<i64> + TryFrom<u64> + FromStr,
<T as TryFrom<i64>>::Error: 'static + StdError + Send + Sync,
T: TryFrom<u64> + FromStr,
T: TryFrom<i64> + FromStr,
<T as TryFrom<u64>>::Error: 'static + StdError + Send + Sync,
<T as TryFrom<i64>>::Error: 'static + StdError + Send + Sync,
<T as FromStr>::Err: 'static + StdError + Send + Sync,
{
if value.format() == MySqlRawValueFormat::Text {
@ -27,51 +23,76 @@ where
}
let mut bytes = value.as_bytes()?;
// start from u64 if the value is marked as unsigned
// otherwise start from i64
let is_unsigned = value.type_info().id().is_unsigned();
// pull at most 8 bytes from the buffer
let len = cmp::max(bytes.len(), 8);
let size = cmp::max(bytes.len(), 8);
Ok(if is_unsigned {
bytes.get_uint_le(len).try_into()?
bytes.get_uint_le(size).try_into()?
} else {
bytes.get_int_le(len).try_into()?
bytes.get_int_le(size).try_into()?
})
}
impl Type<MySql> for u8 {
fn type_id() -> MySqlTypeId {
MySqlTypeId::TINYINT_UNSIGNED
// check that the incoming value is not too large
// to fit into the target SQL type
fn ensure_not_too_large(value: u64, ty: &MySqlTypeInfo) -> encode::Result<()> {
let max = match ty.id() {
MySqlTypeId::TINYINT => i8::MAX as _,
MySqlTypeId::SMALLINT => i16::MAX as _,
MySqlTypeId::MEDIUMINT => 0x7F_FF_FF as _,
MySqlTypeId::INT => i32::MAX as _,
MySqlTypeId::BIGINT => i64::MAX as _,
MySqlTypeId::TINYINT_UNSIGNED => u8::MAX as _,
MySqlTypeId::SMALLINT_UNSIGNED => u16::MAX as _,
MySqlTypeId::MEDIUMINT_UNSIGNED => 0xFF_FF_FF as _,
MySqlTypeId::INT_UNSIGNED => u32::MAX as _,
MySqlTypeId::BIGINT_UNSIGNED => u64::MAX,
// not an integer type
_ => unreachable!(),
};
if value > max {
return Err(encode::Error::msg(format!(
"number `{}` too large to fit in SQL type `{}`",
value,
ty.name()
)));
}
fn compatible(ty: &MySqlTypeInfo) -> bool {
ty.id().is_integer()
}
Ok(())
}
impl Encode<MySql> for u8 {
fn encode(&self, ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result<()> {
match ty.id() {
MySqlTypeId::TINYINT_UNSIGNED => {}
MySqlTypeId::TINYINT if *self > 0x7f => {
return Err(encode::Error::msg(NUMBER_TOO_LARGE));
macro_rules! impl_type_uint {
($ty:ty => $sql:ident) => {
impl Type<MySql> for $ty {
fn type_id() -> MySqlTypeId {
MySqlTypeId::$sql
}
_ => {}
fn compatible(ty: &MySqlTypeInfo) -> bool {
ty.id().is_integer()
}
}
out.buffer().push(*self);
impl Encode<MySql> for $ty {
fn encode(&self, _ty: &MySqlTypeInfo, out: &mut MySqlOutput<'_>) -> encode::Result<()> {
out.buffer().extend_from_slice(&self.to_le_bytes());
Ok(())
}
Ok(())
}
}
impl<'r> Decode<'r, MySql> for $ty {
fn decode(value: MySqlRawValue<'r>) -> decode::Result<Self> {
decode_int_or_uint(&value)
}
}
};
}
impl<'r> Decode<'r, MySql> for u8 {
fn decode(value: MySqlRawValue<'r>) -> decode::Result<Self> {
decode_int_or_uint(&value)
}
}
impl_type_uint! { u8 => TINYINT_UNSIGNED }
impl_type_uint! { u16 => SMALLINT_UNSIGNED }
impl_type_uint! { u32 => INT_UNSIGNED }
impl_type_uint! { u64 => BIGINT_UNSIGNED }