diff --git a/Cargo.toml b/Cargo.toml index f3f53a6ea..5640e308a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ unstable = [ "sqlx-core/unstable" ] postgres = [ "sqlx-core/postgres", "sqlx-macros/postgres" ] mariadb = [ "sqlx-core/mariadb", "sqlx-macros/mariadb" ] macros = [ "sqlx-macros", "proc-macro-hack" ] +chrono = ["sqlx-core/chrono", "sqlx-macros/chrono"] uuid = [ "sqlx-core/uuid", "sqlx-macros/uuid" ] [dependencies] diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 65aa53003..e7eb36317 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -20,6 +20,7 @@ async-std = { version = "1.2.0", default-features = false, features = [ "unstabl async-stream = "0.2.0" bitflags = "1.2.1" byteorder = { version = "1.3.2", default-features = false } +chrono = { version = "0.4", optional = true } futures-channel = "0.3.1" futures-core = "0.3.1" futures-util = "0.3.1" diff --git a/sqlx-core/src/io/buf.rs b/sqlx-core/src/io/buf.rs index 65cc1b7a0..424f7c070 100644 --- a/sqlx-core/src/io/buf.rs +++ b/sqlx-core/src/io/buf.rs @@ -1,6 +1,6 @@ use byteorder::ByteOrder; use memchr::memchr; -use std::{io, str}; +use std::{io, slice, str}; pub trait Buf { fn advance(&mut self, cnt: usize); @@ -98,3 +98,15 @@ impl<'a> Buf for &'a [u8] { Ok(s) } } + +pub trait ToBuf { + fn to_buf(&self) -> &[u8]; +} + +impl ToBuf for [u8] { + fn to_buf(&self) -> &[u8] { self } +} + +impl ToBuf for u8 { + fn to_buf(&self) -> &[u8] { slice::from_ref(self) } +} diff --git a/sqlx-core/src/io/mod.rs b/sqlx-core/src/io/mod.rs index a2098395c..922a1af00 100644 --- a/sqlx-core/src/io/mod.rs +++ b/sqlx-core/src/io/mod.rs @@ -5,4 +5,4 @@ mod buf; mod buf_mut; mod byte_str; -pub use self::{buf::Buf, buf_mut::BufMut, buf_stream::BufStream, byte_str::ByteStr}; +pub use self::{buf::{Buf, ToBuf}, buf_mut::BufMut, buf_stream::BufStream, byte_str::ByteStr}; diff --git a/sqlx-core/src/macros.rs b/sqlx-core/src/macros.rs index 296cc9a7e..e13567c77 100644 --- a/sqlx-core/src/macros.rs +++ b/sqlx-core/src/macros.rs @@ -3,18 +3,12 @@ #[macro_export] macro_rules! __bytes_builder ( ($($b: expr), *) => {{ - use bytes::Buf; - use bytes::IntoBuf; - use bytes::BufMut; + use $crate::io::ToBuf; - let mut bytes = bytes::BytesMut::new(); + let mut buf = Vec::new(); $( - { - let buf = $b.into_buf(); - bytes.reserve(buf.remaining()); - bytes.put(buf); - } + buf.extend_from_slice($b.to_buf()); )* - bytes.freeze() + buf }} ); diff --git a/sqlx-core/src/mariadb/protocol/binary/com_stmt_prepare_ok.rs b/sqlx-core/src/mariadb/protocol/binary/com_stmt_prepare_ok.rs index bee168213..a6c0578d5 100644 --- a/sqlx-core/src/mariadb/protocol/binary/com_stmt_prepare_ok.rs +++ b/sqlx-core/src/mariadb/protocol/binary/com_stmt_prepare_ok.rs @@ -52,7 +52,7 @@ mod tests { use crate::__bytes_builder; #[test] - fn it_decodes_com_stmt_prepare_ok() -> io::Result<()> { + fn it_decodes_com_stmt_prepare_ok() -> crate::Result<()> { #[rustfmt::skip] let buf = &__bytes_builder!( // int<1> 0x00 COM_STMT_PREPARE_OK header diff --git a/sqlx-core/src/mariadb/protocol/response/eof.rs b/sqlx-core/src/mariadb/protocol/response/eof.rs index 6a46eafb3..b6296f48f 100644 --- a/sqlx-core/src/mariadb/protocol/response/eof.rs +++ b/sqlx-core/src/mariadb/protocol/response/eof.rs @@ -38,7 +38,7 @@ mod test { use std::io; #[test] - fn it_decodes_eof_packet() -> io::Result<()> { + fn it_decodes_eof_packet() -> crate::Result<()> { #[rustfmt::skip] let buf = __bytes_builder!( // int<1> 0xfe : EOF header diff --git a/sqlx-core/src/mariadb/protocol/response/ok.rs b/sqlx-core/src/mariadb/protocol/response/ok.rs index e71b49b52..eabc3216f 100644 --- a/sqlx-core/src/mariadb/protocol/response/ok.rs +++ b/sqlx-core/src/mariadb/protocol/response/ok.rs @@ -69,7 +69,7 @@ mod test { use crate::__bytes_builder; #[test] - fn it_decodes_ok_packet() -> io::Result<()> { + fn it_decodes_ok_packet() -> crate::Result<()> { #[rustfmt::skip] let buf = __bytes_builder!( // 0x00 : OK_Packet header or (0xFE if CLIENT_DEPRECATE_EOF is set) diff --git a/sqlx-core/src/mariadb/types/boolean.rs b/sqlx-core/src/mariadb/types/boolean.rs index 440413c16..a97d7cdcb 100644 --- a/sqlx-core/src/mariadb/types/boolean.rs +++ b/sqlx-core/src/mariadb/types/boolean.rs @@ -10,7 +10,7 @@ impl HasSqlType for MariaDb { fn metadata() -> MariaDbTypeMetadata { MariaDbTypeMetadata { // MYSQL_TYPE_TINY - field_type: FieldType(1), + field_type: FieldType::MYSQL_TYPE_TINY, param_flag: ParameterFlag::empty(), } } diff --git a/sqlx-core/src/mariadb/types/character.rs b/sqlx-core/src/mariadb/types/character.rs index 76cf62237..81eca164d 100644 --- a/sqlx-core/src/mariadb/types/character.rs +++ b/sqlx-core/src/mariadb/types/character.rs @@ -14,7 +14,7 @@ impl HasSqlType for MariaDb { fn metadata() -> MariaDbTypeMetadata { MariaDbTypeMetadata { // MYSQL_TYPE_VAR_STRING - field_type: FieldType(254), + field_type: FieldType::MYSQL_TYPE_VAR_STRING, param_flag: ParameterFlag::empty(), } } diff --git a/sqlx-core/src/mariadb/types/chrono.rs b/sqlx-core/src/mariadb/types/chrono.rs new file mode 100644 index 000000000..36a465e48 --- /dev/null +++ b/sqlx-core/src/mariadb/types/chrono.rs @@ -0,0 +1,177 @@ +use crate::{HasSqlType, MariaDb, HasTypeMetadata, Encode, Decode}; +use chrono::{NaiveDateTime, Datelike, Timelike, NaiveTime, NaiveDate}; +use crate::mariadb::types::MariaDbTypeMetadata; +use crate::mariadb::protocol::{FieldType, ParameterFlag}; +use crate::encode::IsNull; + +use crate::io::Buf; + +use std::convert::{TryFrom, TryInto}; +use byteorder::{LittleEndian, ByteOrder}; +use chrono::format::Item::Literal; + +impl HasSqlType for MariaDb { + fn metadata() -> Self::TypeMetadata { + MariaDbTypeMetadata { + field_type: FieldType::MYSQL_TYPE_DATETIME, + param_flag: ParameterFlag::empty() + } + } +} + +impl Encode for NaiveDateTime { + fn encode(&self, buf: &mut Vec) -> IsNull { + // subtract the length byte + let length = Encode::::size_hint(self) - 1; + + buf.push(length as u8); + + encode_date(self.date(), buf); + + if length >= 7 { + buf.push(self.hour() as u8); + buf.push(self.minute() as u8); + buf.push(self.second() as u8); + } + + if length == 11 { + buf.extend_from_slice(&self.timestamp_subsec_micros().to_le_bytes()); + } + + IsNull::No + } + + fn size_hint(&self) -> usize { + match (self.hour(), self.minute(), self.second(), self.timestamp_subsec_micros()) { + // include the length byte + (0, 0, 0, 0) => 5, + (_, _, _, 0) => 8, + (_, _, _, _) => 12, + } + } +} + +impl Decode for NaiveDateTime { + fn decode(raw: Option<&[u8]>) -> Self { + let raw = raw.unwrap(); + let len = raw[0]; + assert_ne!(len, 0, "MySQL zero-dates are not supported"); + + let date = decode_date(&raw[1..]); + + if len >= 7 { + date.and_hms_micro( + raw[5] as u32, + raw[6] as u32, + raw[7] as u32, + if len == 11 { + LittleEndian::read_u32(&raw[8..]) + } else { + 0 + } + ) + } else { + date.and_hms(0, 0, 0) + } + } +} + +impl HasSqlType for MariaDb { + fn metadata() -> Self::TypeMetadata { + MariaDbTypeMetadata { + field_type: FieldType::MYSQL_TYPE_DATE, + param_flag: ParameterFlag::empty() + } + } +} + +impl Encode for NaiveDate { + fn encode(&self, buf: &mut Vec) -> IsNull { + buf.push(4); + encode_date(*self, buf); + IsNull::No + } + + fn size_hint(&self) -> usize { + 5 + } +} + +impl Decode for NaiveDate { + fn decode(raw: Option<&[u8]>) -> Self { + let raw = raw.unwrap(); + assert_eq!(raw[0], 4, "expected only 4 bytes"); + decode_date(&raw[1..]) + } +} + +fn encode_date(date: NaiveDate, buf: &mut Vec) { + // MySQL supports years from 1000 - 9999 + let year = u16::try_from(date.year()) + .unwrap_or_else(|_| panic!("NaiveDateTime out of range for MariaDB: {}", date)); + + buf.extend_from_slice(&year.to_le_bytes()); + buf.push(date.month() as u8); + buf.push(date.day() as u8); +} + +fn decode_date(raw: &[u8]) -> NaiveDate { + NaiveDate::from_ymd( + LittleEndian::read_u16(raw) as i32, + raw[2] as u32, + raw[3] as u32 + ) +} + +#[test] +fn test_encode_date_time() { + let mut buf = Vec::new(); + + // test values from https://dev.mysql.com/doc/internals/en/binary-protocol-value.html + let date1: NaiveDateTime = "2010-10-17T19:27:30.000001".parse().unwrap(); + Encode::::encode(&date1, &mut buf); + assert_eq!(*buf, [11, 218, 7, 10, 17, 19, 27, 30, 1, 0, 0, 0]); + + buf.clear(); + + let date2: NaiveDateTime = "2010-10-17T19:27:30".parse().unwrap(); + Encode::::encode(&date2, &mut buf); + assert_eq!(*buf, [7, 218, 7, 10, 17, 19, 27, 30]); + + buf.clear(); + + let date3: NaiveDateTime = "2010-10-17T00:00:00".parse().unwrap(); + Encode::::encode(&date3, &mut buf); + assert_eq!(*buf, [4, 218, 7, 10, 17]); +} + +#[test] +fn test_decode_date_time() { + // test values from https://dev.mysql.com/doc/internals/en/binary-protocol-value.html + let buf = [11, 218, 7, 10, 17, 19, 27, 30, 1, 0, 0, 0]; + let date1 = >::decode(Some(&buf)); + assert_eq!(date1.to_string(), "2010-10-17 19:27:30.000001"); + + let buf = [7, 218, 7, 10, 17, 19, 27, 30]; + let date2 = >::decode(Some(&buf)); + assert_eq!(date2.to_string(), "2010-10-17 19:27:30"); + + let buf = [4, 218, 7, 10, 17]; + let date3 = >::decode(Some(&buf)); + assert_eq!(date3.to_string(), "2010-10-17 00:00:00"); +} + +#[test] +fn test_encode_date() { + let mut buf = Vec::new(); + let date: NaiveDate = "2010-10-17".parse().unwrap(); + Encode::::encode(&date, &mut buf); + assert_eq!(*buf, [4, 218, 7, 10, 17]); +} + +#[test] +fn test_decode_date() { + let buf = [4, 218, 7, 10, 17]; + let date = >::decode(Some(&buf)); + assert_eq!(date.to_string(), "2010-10-17"); +} diff --git a/sqlx-core/src/mariadb/types/mod.rs b/sqlx-core/src/mariadb/types/mod.rs index e9fc92ab3..4e5b08f54 100644 --- a/sqlx-core/src/mariadb/types/mod.rs +++ b/sqlx-core/src/mariadb/types/mod.rs @@ -9,6 +9,9 @@ pub mod boolean; pub mod character; pub mod numeric; +#[cfg(feature = "chrono")] +pub mod chrono; + #[derive(Debug)] pub struct MariaDbTypeMetadata { pub field_type: FieldType, diff --git a/sqlx-core/src/mariadb/types/numeric.rs b/sqlx-core/src/mariadb/types/numeric.rs index b8e80b55e..e33d8feec 100644 --- a/sqlx-core/src/mariadb/types/numeric.rs +++ b/sqlx-core/src/mariadb/types/numeric.rs @@ -11,7 +11,7 @@ impl HasSqlType for MariaDb { #[inline] fn metadata() -> MariaDbTypeMetadata { MariaDbTypeMetadata { - field_type: FieldType(1), + field_type: FieldType::MYSQL_TYPE_TINY, param_flag: ParameterFlag::empty(), } } diff --git a/sqlx-core/src/postgres/protocol/data_row.rs b/sqlx-core/src/postgres/protocol/data_row.rs index 8c4dbcd92..78c45ed54 100644 --- a/sqlx-core/src/postgres/protocol/data_row.rs +++ b/sqlx-core/src/postgres/protocol/data_row.rs @@ -65,6 +65,7 @@ impl Debug for DataRow { #[cfg(test)] mod tests { use super::{DataRow, Decode}; + use crate::Row; const DATA_ROW: &[u8] = b"\0\x03\0\0\0\x011\0\0\0\x012\0\0\0\x013"; @@ -74,9 +75,9 @@ mod tests { assert_eq!(m.len(), 3); - assert_eq!(m.get(0), Some(&b"1"[..])); - assert_eq!(m.get(1), Some(&b"2"[..])); - assert_eq!(m.get(2), Some(&b"3"[..])); + assert_eq!(m.get_raw(0), Some(&b"1"[..])); + assert_eq!(m.get_raw(1), Some(&b"2"[..])); + assert_eq!(m.get_raw(2), Some(&b"3"[..])); assert_eq!( format!("{:?}", m), diff --git a/sqlx-core/src/postgres/types/chrono.rs b/sqlx-core/src/postgres/types/chrono.rs new file mode 100644 index 000000000..b8cb0d97f --- /dev/null +++ b/sqlx-core/src/postgres/types/chrono.rs @@ -0,0 +1,208 @@ +use crate::{Decode, Postgres, Encode, HasSqlType, HasTypeMetadata}; +use chrono::{NaiveTime, Timelike, NaiveDate, TimeZone, DateTime, NaiveDateTime, Utc, Local, Duration, Date}; +use crate::postgres::types::{PostgresTypeMetadata, PostgresTypeFormat}; +use crate::encode::IsNull; + +use std::convert::TryInto; + +use std::mem::size_of; + +postgres_metadata!( + // time + NaiveTime: PostgresTypeMetadata { + format: PostgresTypeFormat::Binary, + oid: 1083, + array_oid: 1183 + }, + // date + NaiveDate: PostgresTypeMetadata { + format: PostgresTypeFormat::Binary, + oid: 1082, + array_oid: 1182 + }, + // timestamp + NaiveDateTime: PostgresTypeMetadata { + format: PostgresTypeFormat::Binary, + oid: 1114, + array_oid: 1115 + }, + // timestamptz + { Tz: TimeZone } DateTime: PostgresTypeMetadata { + format: PostgresTypeFormat::Binary, + oid: 1184, + array_oid: 1185 + }, + // Date is not covered as Postgres does not have a "date with timezone" type +); + +fn decode>(raw: Option<&[u8]>) -> T { + Decode::::decode(raw) +} + +impl Decode for NaiveTime { + fn decode(raw: Option<&[u8]>) -> Self { + let micros: i64 = decode(raw); + NaiveTime::from_hms(0, 0, 0) + Duration::microseconds(micros) + } +} + +impl Encode for NaiveTime { + fn encode(&self, buf: &mut Vec) -> IsNull { + let micros = (*self - NaiveTime::from_hms(0, 0, 0)) + .num_microseconds() + .expect("shouldn't overflow"); + + Encode::::encode(µs, buf) + } + + fn size_hint(&self) -> usize { + size_of::() + } +} + +impl Decode for NaiveDate { + fn decode(raw: Option<&[u8]>) -> Self { + let days: i32 = decode(raw); + NaiveDate::from_ymd(2000, 1, 1) + Duration::days(days as i64) + } +} + +impl Encode for NaiveDate { + fn encode(&self, buf: &mut Vec) -> IsNull { + let days: i32 = self.signed_duration_since(NaiveDate::from_ymd(2000, 1, 1)) + .num_days() + .try_into() + .unwrap_or_else(|_| panic!("NaiveDate out of range for Postgres: {:?}", self)); + + Encode::::encode(&days, buf) + } + + fn size_hint(&self) -> usize { + size_of::() + } +} + +impl Decode for NaiveDateTime { + fn decode(raw: Option<&[u8]>) -> Self { + let micros: i64 = decode(raw); + postgres_epoch().naive_utc() + .checked_add_signed(Duration::microseconds(micros)) + .unwrap_or_else(|| panic!("Postgres timestamp out of range for NaiveDateTime: {:?}", micros)) + } +} + +impl Encode for NaiveDateTime { + fn encode(&self, buf: &mut Vec) -> IsNull { + let micros = self.signed_duration_since(postgres_epoch().naive_utc()) + .num_microseconds() + .unwrap_or_else(|| panic!("NaiveDateTime out of range for Postgres: {:?}", self)); + + Encode::::encode(µs, buf) + } + + fn size_hint(&self) -> usize { + size_of::() + } +} + +impl Decode for DateTime { + fn decode(raw: Option<&[u8]>) -> Self { + let date_time = >::decode(raw); + DateTime::from_utc(date_time, Utc) + } +} + +impl Decode for DateTime { + fn decode(raw: Option<&[u8]>) -> Self { + let date_time = >::decode(raw); + Local.from_utc_datetime(&date_time) + } +} + +impl Encode for DateTime where Tz::Offset: Copy { + fn encode(&self, buf: &mut Vec) -> IsNull { + Encode::::encode(&self.naive_utc(), buf) + } + + fn size_hint(&self) -> usize { + size_of::() + } +} + +fn postgres_epoch() -> DateTime { + Utc.ymd(2000, 1, 1).and_hms(0, 0, 0) +} + +#[test] +fn test_encode_datetime() { + let mut buf = Vec::new(); + + let date = postgres_epoch(); + Encode::::encode(&date, &mut buf); + assert_eq!(buf, [0; 8]); + buf.clear(); + + // one hour past epoch + let date2 = postgres_epoch() + Duration::hours(1); + Encode::::encode(&date2, &mut buf); + assert_eq!(buf, 3_600_000_000i64.to_be_bytes()); + buf.clear(); + + // some random date + let date3: NaiveDateTime = "2019-12-11T11:01:05".parse().unwrap(); + let expected = dbg!((date3 - postgres_epoch().naive_utc()).num_microseconds().unwrap()); + Encode::::encode(&date3, &mut buf); + assert_eq!(buf, expected.to_be_bytes()); + buf.clear(); +} + +#[test] +fn test_decode_datetime() { + let buf = [0u8; 8]; + let date: NaiveDateTime = Decode::::decode(Some(&buf)); + assert_eq!(date.to_string(), "2000-01-01 00:00:00"); + + let buf = 3_600_000_000i64.to_be_bytes(); + let date: NaiveDateTime = Decode::::decode(Some(&buf)); + assert_eq!(date.to_string(), "2000-01-01 01:00:00"); + + let buf = 629_377_265_000_000i64.to_be_bytes(); + let date: NaiveDateTime = Decode::::decode(Some(&buf)); + assert_eq!(date.to_string(), "2019-12-11 11:01:05"); +} + +#[test] +fn test_encode_date() { + let mut buf = Vec::new(); + + let date = NaiveDate::from_ymd(2000, 1, 1); + Encode::::encode(&date, &mut buf); + assert_eq!(buf, [0; 4]); + buf.clear(); + + let date2 = NaiveDate::from_ymd(2001, 1, 1); + Encode::::encode(&date2, &mut buf); + // 2000 was a leap year + assert_eq!(buf, 366i32.to_be_bytes()); + buf.clear(); + + let date3 = NaiveDate::from_ymd(2019, 12, 11); + Encode::::encode(&date3, &mut buf); + assert_eq!(buf, 7284i32.to_be_bytes()); + buf.clear(); +} + +#[test] +fn test_decode_date() { + let buf = [0; 4]; + let date: NaiveDate = Decode::::decode(Some(&buf)); + assert_eq!(date.to_string(), "2000-01-01"); + + let buf = 366i32.to_be_bytes(); + let date: NaiveDate = Decode::::decode(Some(&buf)); + assert_eq!(date.to_string(), "2001-01-01"); + + let buf = 7284i32.to_be_bytes(); + let date: NaiveDate = Decode::::decode(Some(&buf)); + assert_eq!(date.to_string(), "2019-12-11"); +} diff --git a/sqlx-core/src/postgres/types/mod.rs b/sqlx-core/src/postgres/types/mod.rs index 7dab4e385..1e57396db 100644 --- a/sqlx-core/src/postgres/types/mod.rs +++ b/sqlx-core/src/postgres/types/mod.rs @@ -30,6 +30,18 @@ use super::Postgres; use crate::types::{HasTypeMetadata, TypeMetadata}; +macro_rules! postgres_metadata { + ($($({ $($typarams:tt)* })? $type:path: $meta:expr),*$(,)?) => { + $( + impl$(<$($typarams)*>)? crate::types::HasSqlType<$type> for Postgres { + fn metadata() -> PostgresTypeMetadata { + $meta + } + } + )* + }; +} + mod binary; mod boolean; mod character; @@ -38,6 +50,9 @@ mod numeric; #[cfg(feature = "uuid")] mod uuid; +#[cfg(feature = "chrono")] +mod chrono; + pub enum PostgresTypeFormat { Text = 0, Binary = 1, diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml index 57f9ec853..c1b1d74b5 100644 --- a/sqlx-macros/Cargo.toml +++ b/sqlx-macros/Cargo.toml @@ -19,6 +19,7 @@ quote = "1.0.2" url = "2.1.0" [features] +chrono = ["sqlx/chrono"] mariadb = ["sqlx/mariadb"] postgres = ["sqlx/postgres"] uuid = ["sqlx/uuid"]