diff --git a/Cargo.toml b/Cargo.toml index be0aaf50..a39ad537 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,3 +64,7 @@ required-features = [ "postgres" ] [[test]] name = "mysql-types" required-features = [ "mysql" ] + +[[test]] +name = "mysql-types-chrono" +required-features = [ "mysql", "chrono" ] diff --git a/sqlx-core/src/mysql/types/chrono.rs b/sqlx-core/src/mysql/types/chrono.rs index 8c87c6f0..a0e2d5fb 100644 --- a/sqlx-core/src/mysql/types/chrono.rs +++ b/sqlx-core/src/mysql/types/chrono.rs @@ -1,78 +1,82 @@ -use byteorder::{ByteOrder, LittleEndian}; -use chrono::{Datelike, NaiveDate, NaiveDateTime, Timelike}; +use std::convert::TryFrom; + +use byteorder::{BigEndian, ByteOrder, LittleEndian}; +use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc}; use crate::decode::{Decode, DecodeError}; use crate::encode::Encode; +use crate::io::{Buf, BufMut}; use crate::mysql::protocol::Type; use crate::mysql::types::MySqlTypeMetadata; use crate::mysql::MySql; use crate::types::HasSqlType; -use std::convert::TryFrom; -impl HasSqlType for MySql { +impl HasSqlType> for MySql { fn metadata() -> Self::TypeMetadata { - MySqlTypeMetadata::new(Type::DATETIME) + MySqlTypeMetadata::new(Type::TIMESTAMP) } } -impl Encode for NaiveDateTime { +impl Encode for DateTime { fn encode(&self, buf: &mut Vec) { - // subtract the length byte - let length = Encode::::size_hint(self) - 1; + Encode::::encode(&self.naive_utc(), buf); + } +} - buf.push(length as u8); +impl Decode for DateTime { + fn decode(buf: &[u8]) -> Result { + let naive: NaiveDateTime = Decode::::decode(buf)?; - encode_date(self.date(), buf); + Ok(DateTime::from_utc(naive, Utc)) + } +} - if length >= 7 { - buf.push(self.hour() as u8); - buf.push(self.minute() as u8); - buf.push(self.second() as u8); - } +impl HasSqlType for MySql { + fn metadata() -> Self::TypeMetadata { + MySqlTypeMetadata::new(Type::TIME) + } +} - if length == 11 { - buf.extend_from_slice(&self.timestamp_subsec_micros().to_le_bytes()); - } +impl Encode for NaiveTime { + fn encode(&self, buf: &mut Vec) { + let len = Encode::::size_hint(self) - 1; + buf.push(len as u8); + + // NaiveTime is not negative + buf.push(0); + + // "date on 4 bytes little-endian format" (?) + // https://mariadb.com/kb/en/resultset-row/#teimstamp-binary-encoding + buf.advance(4); + + encode_time(self, len > 9, buf); } fn size_hint(&self) -> usize { - match ( - self.hour(), - self.minute(), - self.second(), - self.timestamp_subsec_micros(), - ) { - // include the length byte - (0, 0, 0, 0) => 5, - (_, _, _, 0) => 8, - (_, _, _, _) => 12, + if self.nanosecond() == 0 { + // if micro_seconds is 0, length is 8 and micro_seconds is not sent + 9 + } else { + // otherwise length is 12 + 13 } } } -impl Decode for NaiveDateTime { - fn decode(raw: &[u8]) -> Result { - let len = raw[0]; +impl Decode for NaiveTime { + fn decode(mut buf: &[u8]) -> Result { + // data length, expecting 8 or 12 (fractional seconds) + let len = buf.get_u8()?; - // TODO: Make an error - assert_ne!(len, 0, "MySQL zero-dates are not supported"); + // is negative : int<1> + let is_negative = buf.get_u8()?; + assert_eq!(is_negative, 0, "Negative dates/times are not supported"); - let date = decode_date(&raw[1..]); + // "date on 4 bytes little-endian format" (?) + // https://mariadb.com/kb/en/resultset-row/#timestamp-binary-encoding + buf.advance(4); - Ok(if len >= 7 { - date.and_hms_micro( - raw[5] as u32, - raw[6] as u32, - raw[7] as u32, - if len == 11 { - LittleEndian::read_u32(&raw[8..]) - } else { - 0 - }, - ) - } else { - date.and_hms(0, 0, 0) - }) + decode_time(len - 5, buf) } } @@ -86,7 +90,7 @@ impl Encode for NaiveDate { fn encode(&self, buf: &mut Vec) { buf.push(4); - encode_date(*self, buf); + encode_date(self, buf); } fn size_hint(&self) -> usize { @@ -95,15 +99,67 @@ impl Encode for NaiveDate { } impl Decode for NaiveDate { - fn decode(raw: &[u8]) -> Result { - // TODO: Return error - assert_eq!(raw[0], 4, "expected only 4 bytes"); - - Ok(decode_date(&raw[1..])) + fn decode(buf: &[u8]) -> Result { + Ok(decode_date(&buf[1..])) } } -fn encode_date(date: NaiveDate, buf: &mut Vec) { +impl HasSqlType for MySql { + fn metadata() -> Self::TypeMetadata { + MySqlTypeMetadata::new(Type::DATETIME) + } +} + +impl Encode for NaiveDateTime { + fn encode(&self, buf: &mut Vec) { + let len = Encode::::size_hint(self) - 1; + buf.push(len as u8); + + encode_date(&self.date(), buf); + + if len > 4 { + encode_time(&self.time(), len > 8, buf); + } + } + + fn size_hint(&self) -> usize { + // to save space the packet can be compressed: + match ( + self.hour(), + self.minute(), + self.second(), + self.timestamp_subsec_nanos(), + ) { + // if hour, minutes, seconds and micro_seconds are all 0, + // length is 4 and no other field is sent + (0, 0, 0, 0) => 5, + + // if micro_seconds is 0, length is 7 + // and micro_seconds is not sent + (_, _, _, 0) => 8, + + // otherwise length is 11 + (_, _, _, _) => 12, + } + } +} + +impl Decode for NaiveDateTime { + fn decode(buf: &[u8]) -> Result { + let len = buf[0]; + let date = decode_date(&buf[1..]); + + let dt = if len > 4 { + date.and_time(decode_time(len - 4, &buf[5..])?) + } else { + date.and_hms(0, 0, 0) + }; + + Ok(dt) + } +} + +fn encode_date(date: &NaiveDate, buf: &mut Vec) { // MySQL supports years from 1000 - 9999 let year = u16::try_from(date.year()) .unwrap_or_else(|_| panic!("NaiveDateTime out of range for Mysql: {}", date)); @@ -113,14 +169,44 @@ fn encode_date(date: NaiveDate, buf: &mut Vec) { buf.push(date.day() as u8); } -fn decode_date(raw: &[u8]) -> NaiveDate { +fn decode_date(buf: &[u8]) -> NaiveDate { NaiveDate::from_ymd( - LittleEndian::read_u16(raw) as i32, - raw[2] as u32, - raw[3] as u32, + LittleEndian::read_u16(buf) as i32, + buf[2] as u32, + buf[3] as u32, ) } +fn encode_time(time: &NaiveTime, include_micros: bool, buf: &mut Vec) { + buf.push(time.hour() as u8); + buf.push(time.minute() as u8); + buf.push(time.second() as u8); + + if include_micros { + buf.put_u32::((time.nanosecond() / 1000) as u32); + } +} + +fn decode_time(len: u8, mut buf: &[u8]) -> Result { + let hour = buf.get_u8()?; + let minute = buf.get_u8()?; + let seconds = buf.get_u8()?; + + let micros = if len > 3 { + // microseconds : int + buf.get_uint::(buf.len())? + } else { + 0 + }; + + Ok(NaiveTime::from_hms_micro( + hour as u32, + minute as u32, + seconds as u32, + micros as u32, + )) +} + #[test] fn test_encode_date_time() { let mut buf = Vec::new(); diff --git a/tests/mysql-types-chrono.rs b/tests/mysql-types-chrono.rs new file mode 100644 index 00000000..3711c2d3 --- /dev/null +++ b/tests/mysql-types-chrono.rs @@ -0,0 +1,81 @@ +use sqlx::types::chrono::{DateTime, NaiveDate, NaiveTime, Utc}; +use sqlx::{mysql::MySqlConnection, Connection, Row}; + +async fn connect() -> anyhow::Result { + Ok(MySqlConnection::open(dotenv::var("DATABASE_URL")?).await?) +} + +#[async_std::test] +async fn mysql_chrono_date() -> anyhow::Result<()> { + let mut conn = connect().await?; + + let value = NaiveDate::from_ymd(2019, 1, 2); + + let row = sqlx::query("SELECT DATE '2019-01-02' = ?, ?") + .bind(&value) + .bind(&value) + .fetch_one(&mut conn) + .await?; + + assert!(row.get::(0)); + assert_eq!(value, row.get(1)); + + Ok(()) +} + +#[async_std::test] +async fn mysql_chrono_date_time() -> anyhow::Result<()> { + let mut conn = connect().await?; + + let value = NaiveDate::from_ymd(2019, 1, 2).and_hms(5, 10, 20); + + let row = sqlx::query("SELECT '2019-01-02 05:10:20' = ?, ?") + .bind(&value) + .bind(&value) + .fetch_one(&mut conn) + .await?; + + assert!(row.get::(0)); + assert_eq!(value, row.get(1)); + + Ok(()) +} + +#[async_std::test] +async fn mysql_chrono_time() -> anyhow::Result<()> { + let mut conn = connect().await?; + + let value = NaiveTime::from_hms_micro(5, 10, 20, 115100); + + let row = sqlx::query("SELECT TIME '05:10:20.115100' = ?, TIME '05:10:20.115100'") + .bind(&value) + .fetch_one(&mut conn) + .await?; + + assert!(row.get::(0)); + assert_eq!(value, row.get(1)); + + Ok(()) +} + +#[async_std::test] +async fn mysql_chrono_timestamp() -> anyhow::Result<()> { + let mut conn = connect().await?; + + let value = DateTime::::from_utc( + NaiveDate::from_ymd(2019, 1, 2).and_hms_micro(5, 10, 20, 115100), + Utc, + ); + + let row = sqlx::query( + "SELECT TIMESTAMP '2019-01-02 05:10:20.115100' = ?, TIMESTAMP '2019-01-02 05:10:20.115100'", + ) + .bind(&value) + .fetch_one(&mut conn) + .await?; + + assert!(row.get::(0)); + assert_eq!(value, row.get(1)); + + Ok(()) +} diff --git a/tests/mysql-types.rs b/tests/mysql-types.rs index 2d7544f4..d612aab3 100644 --- a/tests/mysql-types.rs +++ b/tests/mysql-types.rs @@ -1,13 +1,14 @@ -use sqlx::{mysql::MySqlConnection, Connection as _, Row}; +use sqlx::{mysql::MySqlConnection, Connection, Row}; + +async fn connect() -> anyhow::Result { + Ok(MySqlConnection::open(dotenv::var("DATABASE_URL")?).await?) +} macro_rules! test { ($name:ident: $ty:ty: $($text:literal == $value:expr),+) => { #[async_std::test] - async fn $name () -> sqlx::Result<()> { - let mut conn = - MySqlConnection::open( - &dotenv::var("DATABASE_URL").expect("DATABASE_URL must be set") - ).await?; + async fn $name () -> anyhow::Result<()> { + let mut conn = connect().await?; $( let row = sqlx::query(&format!("SELECT {} = ?, ? as _1", $text)) @@ -29,12 +30,17 @@ macro_rules! test { } test!(mysql_bool: bool: "false" == false, "true" == true); + test!(mysql_tiny_unsigned: u8: "253" == 253_u8); test!(mysql_tiny: i8: "5" == 5_i8); + test!(mysql_medium_unsigned: u16: "21415" == 21415_u16); test!(mysql_short: i16: "21415" == 21415_i16); + test!(mysql_long_unsigned: u32: "2141512" == 2141512_u32); test!(mysql_long: i32: "2141512" == 2141512_i32); + test!(mysql_longlong_unsigned: u64: "2141512" == 2141512_u64); test!(mysql_longlong: i64: "2141512" == 2141512_i64); + test!(mysql_string: String: "'helloworld'" == "helloworld");