diff --git a/Cargo.toml b/Cargo.toml index 736f49cd..a27399e9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -92,6 +92,10 @@ required-features = [ "mysql" ] name = "mysql-raw" required-features = [ "mysql" ] +[[test]] +name = "mysql-derives" +required-features = [ "mysql", "macros" ] + [[test]] name = "postgres" required-features = [ "postgres" ] @@ -105,12 +109,9 @@ name = "postgres-types" required-features = [ "postgres" ] [[test]] -name = "mysql-types" -required-features = [ "mysql" ] +name = "postgres-derives" +required-features = [ "postgres", "macros" ] [[test]] -name = "derives" -required-features = [ "macros" ] - -[profile.release] -lto = true +name = "mysql-types" +required-features = [ "mysql" ] diff --git a/sqlx-core/src/database.rs b/sqlx-core/src/database.rs index 70dea800..c9a11dcb 100644 --- a/sqlx-core/src/database.rs +++ b/sqlx-core/src/database.rs @@ -9,7 +9,7 @@ use crate::types::TypeInfo; /// A database driver. /// /// This trait encapsulates a complete driver implementation to a specific -/// database (e.g., MySQL, Postgres). +/// database (e.g., **MySQL**, **Postgres**). pub trait Database where Self: Sized + Send + 'static, @@ -29,21 +29,54 @@ where /// The Rust type of table identifiers for this database. type TableId: Display + Clone; - type RawBuffer; + /// The Rust type used as the buffer when encoding arguments. + /// + /// For example, **Postgres** and **MySQL** use `Vec`; however, **SQLite** uses `Vec`. + type RawBuffer: Default; } +/// Associate [`Database`] with a `RawValue` of a generic lifetime. +/// +/// --- +/// +/// The upcoming Rust feature, [Generic Associated Types], should obviate +/// the need for this trait. +/// +/// [Generic Associated Types]: https://www.google.com/search?q=generic+associated+types+rust&oq=generic+associated+types+rust&aqs=chrome..69i57j0l5.3327j0j7&sourceid=chrome&ie=UTF-8 pub trait HasRawValue<'c> { + /// The Rust type used to hold a not-yet-decoded value that has just been + /// received from the database. + /// + /// For example, **Postgres** and **MySQL** use `&'c [u8]`; however, **SQLite** uses `SqliteValue<'c>`. type RawValue; } +/// Associate [`Database`] with a [`Cursor`] of a generic lifetime. +/// +/// --- +/// +/// The upcoming Rust feature, [Generic Associated Types], should obviate +/// the need for this trait. +/// +/// [Generic Associated Types]: https://www.google.com/search?q=generic+associated+types+rust&oq=generic+associated+types+rust&aqs=chrome..69i57j0l5.3327j0j7&sourceid=chrome&ie=UTF-8 pub trait HasCursor<'c, 'q> { type Database: Database; + /// The concrete `Cursor` implementation for this database. type Cursor: Cursor<'c, 'q, Database = Self::Database>; } +/// Associate [`Database`] with a [`Row`] of a generic lifetime. +/// +/// --- +/// +/// The upcoming Rust feature, [Generic Associated Types], should obviate +/// the need for this trait. +/// +/// [Generic Associated Types]: https://www.google.com/search?q=generic+associated+types+rust&oq=generic+associated+types+rust&aqs=chrome..69i57j0l5.3327j0j7&sourceid=chrome&ie=UTF-8 pub trait HasRow<'c> { type Database: Database; + /// The concrete `Row` implementation for this database. type Row: Row<'c, Database = Self::Database>; } diff --git a/sqlx-core/src/mysql/protocol/row.rs b/sqlx-core/src/mysql/protocol/row.rs index 384db51c..15096376 100644 --- a/sqlx-core/src/mysql/protocol/row.rs +++ b/sqlx-core/src/mysql/protocol/row.rs @@ -126,6 +126,7 @@ impl<'c> Row<'c> { | TypeId::LONG_BLOB | TypeId::CHAR | TypeId::TEXT + | TypeId::ENUM | TypeId::VAR_CHAR => { let (len_size, len) = get_lenenc(&buffer[index..]); diff --git a/sqlx-core/src/mysql/protocol/type.rs b/sqlx-core/src/mysql/protocol/type.rs index 38ecb453..792ccb57 100644 --- a/sqlx-core/src/mysql/protocol/type.rs +++ b/sqlx-core/src/mysql/protocol/type.rs @@ -39,6 +39,9 @@ type_id_consts! { pub const VAR_CHAR: TypeId = TypeId(253); // or VAR_BINARY pub const TEXT: TypeId = TypeId(252); // or BLOB + // Enum + pub const ENUM: TypeId = TypeId(247); + // More Bytes pub const TINY_BLOB: TypeId = TypeId(249); pub const MEDIUM_BLOB: TypeId = TypeId(250); diff --git a/sqlx-core/src/mysql/types/mod.rs b/sqlx-core/src/mysql/types/mod.rs index 3c1311b0..2c4421f1 100644 --- a/sqlx-core/src/mysql/types/mod.rs +++ b/sqlx-core/src/mysql/types/mod.rs @@ -85,6 +85,7 @@ impl TypeInfo for MySqlTypeInfo { | TypeId::TINY_BLOB | TypeId::MEDIUM_BLOB | TypeId::LONG_BLOB + | TypeId::ENUM if (self.is_binary == other.is_binary) && match other.id { TypeId::VAR_CHAR @@ -92,7 +93,8 @@ impl TypeInfo for MySqlTypeInfo { | TypeId::CHAR | TypeId::TINY_BLOB | TypeId::MEDIUM_BLOB - | TypeId::LONG_BLOB => true, + | TypeId::LONG_BLOB + | TypeId::ENUM => true, _ => false, } => diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs index f1dd5997..853514ac 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -21,7 +21,7 @@ mod row; mod sasl; mod stream; mod tls; -mod types; +pub mod types; /// An alias for [`Pool`][crate::Pool], specialized for **Postgres**. pub type PgPool = crate::pool::Pool; diff --git a/sqlx-core/src/postgres/row.rs b/sqlx-core/src/postgres/row.rs index 9158b2a7..c9224795 100644 --- a/sqlx-core/src/postgres/row.rs +++ b/sqlx-core/src/postgres/row.rs @@ -11,6 +11,7 @@ use crate::row::{ColumnIndex, Row}; /// A value from Postgres. This may be in a BINARY or TEXT format depending /// on the data type and if the query was prepared or not. +#[derive(Debug)] pub enum PgValue<'c> { Binary(&'c [u8]), Text(&'c str), diff --git a/sqlx-core/src/postgres/types/mod.rs b/sqlx-core/src/postgres/types/mod.rs index f017ac7f..ab202c40 100644 --- a/sqlx-core/src/postgres/types/mod.rs +++ b/sqlx-core/src/postgres/types/mod.rs @@ -11,8 +11,11 @@ mod bool; mod bytes; mod float; mod int; +mod record; mod str; +pub use self::record::{PgRecordDecoder, PgRecordEncoder}; + #[cfg(feature = "chrono")] mod chrono; @@ -57,6 +60,10 @@ impl PgTypeInfo { _ => None, } } + + pub fn oid(&self) -> u32 { + self.id.0 + } } impl Display for PgTypeInfo { diff --git a/sqlx-core/src/postgres/types/record.rs b/sqlx-core/src/postgres/types/record.rs new file mode 100644 index 00000000..c4955704 --- /dev/null +++ b/sqlx-core/src/postgres/types/record.rs @@ -0,0 +1,309 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::io::Buf; +use crate::postgres::protocol::TypeId; +use crate::postgres::{PgTypeInfo, PgValue, Postgres}; +use crate::types::Type; +use byteorder::BigEndian; +use std::convert::TryInto; + +pub struct PgRecordEncoder<'a> { + buf: &'a mut Vec, + beg: usize, + num: u32, +} + +impl<'a> PgRecordEncoder<'a> { + pub fn new(buf: &'a mut Vec) -> Self { + // reserve space for a field count + buf.extend_from_slice(&(0_u32).to_be_bytes()); + + Self { + beg: buf.len(), + buf, + num: 0, + } + } + + pub fn finish(&mut self) { + // replaces zeros with actual length + self.buf[self.beg - 4..self.beg].copy_from_slice(&self.num.to_be_bytes()); + } + + pub fn encode(&mut self, value: T) -> &mut Self + where + T: Type + Encode, + { + // write oid + let info = T::type_info(); + self.buf.extend(&info.oid().to_be_bytes()); + + // write zeros for length + self.buf.extend(&[0; 4]); + + let start = self.buf.len(); + if let IsNull::Yes = value.encode_nullable(self.buf) { + // replaces zeros with actual length + self.buf[start - 4..start].copy_from_slice(&(-1_i32).to_be_bytes()); + } else { + let end = self.buf.len(); + let size = end - start; + + // replaces zeros with actual length + self.buf[start - 4..start].copy_from_slice(&(size as u32).to_be_bytes()); + } + + // keep track of count + self.num += 1; + + self + } +} + +pub struct PgRecordDecoder<'de> { + value: PgValue<'de>, +} + +impl<'de> PgRecordDecoder<'de> { + pub fn new(value: Option>) -> crate::Result { + let mut value: PgValue = value.try_into()?; + + match value { + PgValue::Binary(ref mut buf) => { + let _expected_len = buf.get_u32::()?; + } + + PgValue::Text(ref mut s) => { + // remove outer ( ... ) + *s = &s[1..(s.len() - 1)]; + } + } + + Ok(Self { value }) + } + + pub fn decode(&mut self) -> crate::Result + where + T: Decode<'de, Postgres>, + { + match self.value { + PgValue::Binary(ref mut buf) => { + // TODO: We should fail if this type is not _compatible_; but + // I want to make sure we handle this _and_ the outer level + // type mismatch errors at the same time + let _oid = buf.get_u32::()?; + let len = buf.get_i32::()? as isize; + + let value = if len < 0 { + T::decode(None)? + } else { + let value_buf = &buf[..(len as usize)]; + *buf = &buf[(len as usize)..]; + + T::decode(Some(PgValue::Binary(value_buf)))? + }; + + Ok(value) + } + + PgValue::Text(ref mut s) => { + let mut in_quotes = false; + let mut in_escape = false; + let mut is_quoted = false; + let mut prev_ch = '\0'; + let mut eos = false; + let mut prev_index = 0; + let mut value = String::new(); + + let index = 'outer: loop { + let mut iter = s.char_indices(); + while let Some((index, ch)) = iter.next() { + match ch { + ',' if !in_quotes => { + break 'outer Some(prev_index); + } + + ',' if prev_ch == '\0' => { + break 'outer None; + } + + '"' if prev_ch == '"' && index != 1 => { + // Quotes are escaped with another quote + in_quotes = false; + value.push('"'); + } + + '"' if in_quotes => { + in_quotes = false; + } + + '\'' if in_escape => { + in_escape = false; + value.push('\''); + } + + '"' if in_escape => { + in_escape = false; + value.push('"'); + } + + '\\' if in_escape => { + in_escape = false; + value.push('\\'); + } + + '\\' => { + in_escape = true; + } + + '"' => { + is_quoted = true; + in_quotes = true; + } + + ch => { + value.push(ch); + } + } + + prev_index = index; + prev_ch = ch; + } + + eos = true; + + break 'outer if prev_ch == '\0' { + // NULL values have zero characters + // Empty strings are "" + None + } else { + Some(prev_index) + }; + }; + + let value = index.map(|index| { + let mut s = &s[..=index]; + + if is_quoted { + s = &s[1..s.len() - 1]; + } + + PgValue::Text(s) + }); + + let value = T::decode(value)?; + + if !eos { + *s = &s[index.unwrap_or(0) + 2..]; + } else { + *s = ""; + } + + Ok(value) + } + } + } +} + +macro_rules! impl_pg_record_for_tuple { + ($( $idx:ident : $T:ident ),+) => { + impl<$($T,)+> Type for ($($T,)+) { + #[inline] + fn type_info() -> PgTypeInfo { + PgTypeInfo { + id: TypeId(2249), + name: Some("RECORD".into()), + } + } + } + + impl<'de, $($T,)+> Decode<'de, Postgres> for ($($T,)+) + where + $($T: crate::types::Type,)+ + $($T: crate::decode::Decode<'de, Postgres>,)+ + { + fn decode(value: Option>) -> crate::Result { + let mut decoder = PgRecordDecoder::new(value)?; + + $(let $idx: $T = decoder.decode()?;)+ + + Ok(($($idx,)+)) + } + } + }; +} + +impl_pg_record_for_tuple!(_1: T1); + +impl_pg_record_for_tuple!(_1: T1, _2: T2); + +impl_pg_record_for_tuple!(_1: T1, _2: T2, _3: T3); + +impl_pg_record_for_tuple!(_1: T1, _2: T2, _3: T3, _4: T4); + +impl_pg_record_for_tuple!(_1: T1, _2: T2, _3: T3, _4: T4, _5: T5); + +impl_pg_record_for_tuple!(_1: T1, _2: T2, _3: T3, _4: T4, _5: T5, _6: T6); + +impl_pg_record_for_tuple!(_1: T1, _2: T2, _3: T3, _4: T4, _5: T5, _6: T6, _7: T7); + +impl_pg_record_for_tuple!( + _1: T1, + _2: T2, + _3: T3, + _4: T4, + _5: T5, + _6: T6, + _7: T7, + _8: T8 +); + +impl_pg_record_for_tuple!( + _1: T1, + _2: T2, + _3: T3, + _4: T4, + _5: T5, + _6: T6, + _7: T7, + _8: T8, + _9: T9 +); + +#[test] +fn test_encode_field() { + let value = "Foo Bar"; + let mut raw_encoded = Vec::new(); + <&str as Encode>::encode(&value, &mut raw_encoded); + let mut field_encoded = Vec::new(); + + let mut encoder = PgRecordEncoder::new(&mut field_encoded); + encoder.encode(&value); + + // check oid + let oid = <&str as Type>::type_info().oid(); + let field_encoded_oid = u32::from_be_bytes(field_encoded[4..8].try_into().unwrap()); + assert_eq!(oid, field_encoded_oid); + + // check length + let field_encoded_length = u32::from_be_bytes(field_encoded[8..12].try_into().unwrap()); + assert_eq!(raw_encoded.len(), field_encoded_length as usize); + + // check data + assert_eq!(raw_encoded, &field_encoded[12..]); +} + +#[test] +fn test_decode_field() { + let value = "Foo Bar".to_string(); + + let mut buf = Vec::new(); + let mut encoder = PgRecordEncoder::new(&mut buf); + encoder.encode(&value); + + let mut buf = buf.as_slice(); + let mut decoder = PgRecordDecoder::new(Some(PgValue::Binary(buf))).unwrap(); + + let value_decoded: String = decoder.decode().unwrap(); + assert_eq!(value_decoded, value); +} diff --git a/sqlx-core/src/sqlite/arguments.rs b/sqlx-core/src/sqlite/arguments.rs index 28099fef..8d4c6618 100644 --- a/sqlx-core/src/sqlite/arguments.rs +++ b/sqlx-core/src/sqlite/arguments.rs @@ -9,7 +9,7 @@ use libsqlite3_sys::{ }; use crate::arguments::Arguments; -use crate::encode::Encode; +use crate::encode::{Encode, IsNull}; use crate::sqlite::statement::Statement; use crate::sqlite::Sqlite; use crate::sqlite::SqliteError; @@ -63,7 +63,9 @@ impl Arguments for SqliteArguments { where T: Encode + Type, { - value.encode(&mut self.values); + if let IsNull::Yes = value.encode_nullable(&mut self.values) { + self.values.push(SqliteArgumentValue::Null); + } } } diff --git a/sqlx-macros/src/derives.rs b/sqlx-macros/src/derives.rs deleted file mode 100644 index b4e73a07..00000000 --- a/sqlx-macros/src/derives.rs +++ /dev/null @@ -1,108 +0,0 @@ -use quote::quote; -use syn::{parse_quote, Data, DataStruct, DeriveInput, Fields, FieldsUnnamed}; - -pub(crate) fn expand_derive_encode(input: DeriveInput) -> syn::Result { - match &input.data { - Data::Struct(DataStruct { - fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), - .. - }) if unnamed.len() == 1 => { - let ident = &input.ident; - let ty = &unnamed.first().unwrap().ty; - - // extract type generics - let generics = &input.generics; - let (_, ty_generics, _) = generics.split_for_impl(); - - // add db type for impl generics & where clause - let mut generics = generics.clone(); - generics.params.insert(0, parse_quote!(DB: sqlx::Database)); - generics - .make_where_clause() - .predicates - .push(parse_quote!(#ty: sqlx::encode::Encode)); - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - Ok(quote!( - impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { - fn encode(&self, buf: &mut ::RawBuffer) { - sqlx::encode::Encode::encode(&self.0, buf) - } - fn encode_nullable(&self, buf: &mut ::RawBuffer) -> sqlx::encode::IsNull { - sqlx::encode::Encode::encode_nullable(&self.0, buf) - } - fn size_hint(&self) -> usize { - sqlx::encode::Encode::size_hint(&self.0) - } - } - )) - } - _ => Err(syn::Error::new_spanned( - input, - "expected a tuple struct with a single field", - )), - } -} - -pub(crate) fn expand_derive_decode(input: DeriveInput) -> syn::Result { - match &input.data { - Data::Struct(DataStruct { - fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), - .. - }) if unnamed.len() == 1 => { - let ident = &input.ident; - let ty = &unnamed.first().unwrap().ty; - - // extract type generics - let generics = &input.generics; - let (_, ty_generics, _) = generics.split_for_impl(); - - let mut impls = Vec::new(); - - if cfg!(feature = "postgres") { - let mut generics = generics.clone(); - generics.params.insert(0, parse_quote!('de)); - generics - .make_where_clause() - .predicates - .push(parse_quote!(#ty: sqlx::decode::Decode<'de, sqlx::Postgres>)); - - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - impls.push(quote!( - impl #impl_generics sqlx::decode::Decode<'de, sqlx::Postgres> for #ident #ty_generics #where_clause { - fn decode(value: >::RawValue) -> sqlx::Result { - <#ty as sqlx::decode::Decode<'de, sqlx::Postgres>>::decode(value).map(Self) - } - } - )); - } - - if cfg!(feature = "mysql") { - let mut generics = generics.clone(); - generics.params.insert(0, parse_quote!('de)); - generics - .make_where_clause() - .predicates - .push(parse_quote!(#ty: sqlx::decode::Decode<'de, sqlx::MySql>)); - - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - impls.push(quote!( - impl #impl_generics sqlx::decode::Decode<'de, sqlx::MySql> for #ident #ty_generics #where_clause { - fn decode(value: >::RawValue) -> sqlx::Result { - <#ty as sqlx::decode::Decode<'de, sqlx::MySql>>::decode(value).map(Self) - } - } - )); - } - - // panic!("{}", q) - Ok(quote!(#(#impls)*)) - } - _ => Err(syn::Error::new_spanned( - input, - "expected a tuple struct with a single field", - )), - } -} diff --git a/sqlx-macros/src/derives/attributes.rs b/sqlx-macros/src/derives/attributes.rs new file mode 100644 index 00000000..3c3d1f66 --- /dev/null +++ b/sqlx-macros/src/derives/attributes.rs @@ -0,0 +1,291 @@ +use proc_macro2::Ident; +use syn::punctuated::Punctuated; +use syn::token::Comma; +use syn::{Attribute, DeriveInput, Field, Lit, Meta, MetaNameValue, NestedMeta, Variant}; + +macro_rules! assert_attribute { + ($e:expr, $err:expr, $input:expr) => { + if !$e { + return Err(syn::Error::new_spanned($input, $err)); + } + }; +} + +macro_rules! fail { + ($t:expr, $m:expr) => { + return Err(syn::Error::new_spanned($t, $m)); + }; +} + +macro_rules! try_set { + ($i:ident, $v:expr, $t:expr) => { + match $i { + None => $i = Some($v), + Some(_) => fail!($t, "duplicate attribute"), + } + }; +} + +#[derive(Copy, Clone)] +pub enum RenameAll { + LowerCase, +} + +pub struct SqlxContainerAttributes { + pub transparent: bool, + pub postgres_oid: Option, + pub rename_all: Option, + pub repr: Option, +} + +pub struct SqlxChildAttributes { + pub rename: Option, +} + +pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result { + let mut transparent = None; + let mut postgres_oid = None; + let mut repr = None; + let mut rename_all = None; + + for attr in input { + let meta = attr + .parse_meta() + .map_err(|e| syn::Error::new_spanned(attr, e))?; + match meta { + Meta::List(list) if list.path.is_ident("sqlx") => { + for value in list.nested.iter() { + match value { + NestedMeta::Meta(meta) => match meta { + Meta::Path(p) if p.is_ident("transparent") => { + try_set!(transparent, true, value) + } + + Meta::NameValue(MetaNameValue { + path, + lit: Lit::Str(val), + .. + }) if path.is_ident("rename_all") => { + let val = match &*val.value() { + "lowercase" => RenameAll::LowerCase, + + _ => fail!(meta, "unexpected value for rename_all"), + }; + + try_set!(rename_all, val, value) + } + + Meta::List(list) if list.path.is_ident("postgres") => { + for value in list.nested.iter() { + match value { + NestedMeta::Meta(Meta::NameValue(MetaNameValue { + path, + lit: Lit::Int(val), + .. + })) if path.is_ident("oid") => { + try_set!(postgres_oid, val.base10_parse()?, value); + } + u => fail!(u, "unexpected value"), + } + } + } + + u => fail!(u, "unexpected attribute"), + }, + u => fail!(u, "unexpected attribute"), + } + } + } + Meta::List(list) if list.path.is_ident("repr") => { + if list.nested.len() != 1 { + fail!(&list.nested, "expected one value") + } + match list.nested.first().unwrap() { + NestedMeta::Meta(Meta::Path(p)) if p.get_ident().is_some() => { + try_set!(repr, p.get_ident().unwrap().clone(), list); + } + u => fail!(u, "unexpected value"), + } + } + _ => {} + } + } + + Ok(SqlxContainerAttributes { + transparent: transparent.unwrap_or(false), + postgres_oid, + repr, + rename_all, + }) +} + +pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result { + let mut rename = None; + + for attr in input { + let meta = attr + .parse_meta() + .map_err(|e| syn::Error::new_spanned(attr, e))?; + + match meta { + Meta::List(list) if list.path.is_ident("sqlx") => { + for value in list.nested.iter() { + match value { + NestedMeta::Meta(meta) => match meta { + Meta::NameValue(MetaNameValue { + path, + lit: Lit::Str(val), + .. + }) if path.is_ident("rename") => try_set!(rename, val.value(), value), + + u => fail!(u, "unexpected attribute"), + }, + u => fail!(u, "unexpected attribute"), + } + } + } + _ => {} + } + } + + Ok(SqlxChildAttributes { rename }) +} + +pub fn check_transparent_attributes(input: &DeriveInput, field: &Field) -> syn::Result<()> { + let attributes = parse_container_attributes(&input.attrs)?; + + assert_attribute!( + attributes.transparent, + "expected #[sqlx(transparent)]", + input + ); + + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + input + ); + + assert_attribute!( + attributes.rename_all.is_none(), + "unexpected #[sqlx(rename_all = ..)]", + field + ); + + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); + + let attributes = parse_child_attributes(&field.attrs)?; + + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + field + ); + + Ok(()) +} + +pub fn check_enum_attributes<'a>(input: &'a DeriveInput) -> syn::Result { + let attributes = parse_container_attributes(&input.attrs)?; + + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + input + ); + + Ok(attributes) +} + +pub fn check_weak_enum_attributes( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = check_enum_attributes(input)?; + + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + input + ); + + assert_attribute!(attributes.repr.is_some(), "expected #[repr(..)]", input); + + assert_attribute!( + attributes.rename_all.is_none(), + "unexpected #[sqlx(c = ..)]", + input + ); + + for variant in variants { + let attributes = parse_child_attributes(&variant.attrs)?; + + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + variant + ); + } + + Ok(attributes) +} + +pub fn check_strong_enum_attributes( + input: &DeriveInput, + _variants: &Punctuated, +) -> syn::Result { + let attributes = check_enum_attributes(input)?; + + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_some(), + "expected #[sqlx(postgres(oid = ..))]", + input + ); + + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); + + Ok(attributes) +} + +pub fn check_struct_attributes<'a>( + input: &'a DeriveInput, + fields: &Punctuated, +) -> syn::Result { + let attributes = parse_container_attributes(&input.attrs)?; + + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + input + ); + + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_some(), + "expected #[sqlx(postgres(oid = ..))]", + input + ); + + assert_attribute!( + attributes.rename_all.is_none(), + "unexpected #[sqlx(rename_all = ..)]", + input + ); + + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); + + for field in fields { + let attributes = parse_child_attributes(&field.attrs)?; + + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + field + ); + } + + Ok(attributes) +} diff --git a/sqlx-macros/src/derives/decode.rs b/sqlx-macros/src/derives/decode.rs new file mode 100644 index 00000000..5b599019 --- /dev/null +++ b/sqlx-macros/src/derives/decode.rs @@ -0,0 +1,212 @@ +use super::attributes::{ + check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, + check_weak_enum_attributes, parse_child_attributes, parse_container_attributes, +}; +use super::rename_all; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::token::Comma; +use syn::{ + parse_quote, Arm, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed, + FieldsUnnamed, Stmt, Variant, +}; + +pub fn expand_derive_decode(input: &DeriveInput) -> syn::Result { + let attrs = parse_container_attributes(&input.attrs)?; + match &input.data { + Data::Struct(DataStruct { + fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), + .. + }) if unnamed.len() == 1 => { + expand_derive_decode_transparent(input, unnamed.first().unwrap()) + } + Data::Enum(DataEnum { variants, .. }) => match attrs.repr { + Some(_) => expand_derive_decode_weak_enum(input, variants), + None => expand_derive_decode_strong_enum(input, variants), + }, + Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { named, .. }), + .. + }) => expand_derive_decode_struct(input, named), + Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")), + Data::Struct(DataStruct { + fields: Fields::Unnamed(..), + .. + }) => Err(syn::Error::new_spanned( + input, + "structs with zero or more than one unnamed field are not supported", + )), + Data::Struct(DataStruct { + fields: Fields::Unit, + .. + }) => Err(syn::Error::new_spanned( + input, + "unit structs are not supported", + )), + } +} + +fn expand_derive_decode_transparent( + input: &DeriveInput, + field: &Field, +) -> syn::Result { + check_transparent_attributes(input, field)?; + + let ident = &input.ident; + let ty = &field.ty; + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics.params.insert(0, parse_quote!('de)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::decode::Decode<'de, DB>)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + let tts = quote!( + impl #impl_generics sqlx::decode::Decode<'de, DB> for #ident #ty_generics #where_clause { + fn decode(value: >::RawValue) -> sqlx::Result { + <#ty as sqlx::decode::Decode<'de, DB>>::decode(value).map(Self) + } + } + ); + + Ok(tts) +} + +fn expand_derive_decode_weak_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attr = check_weak_enum_attributes(input, &variants)?; + let repr = attr.repr.unwrap(); + + let ident = &input.ident; + let ident_s = ident.to_string(); + + let arms = variants + .iter() + .map(|v| { + let id = &v.ident; + parse_quote!(_ if (#ident :: #id as #repr) == value => Ok(#ident :: #id),) + }) + .collect::>(); + + Ok(quote!( + impl<'de, DB: sqlx::Database> sqlx::decode::Decode<'de, DB> for #ident where #repr: sqlx::decode::Decode<'de, DB> { + fn decode(value: >::RawValue) -> sqlx::Result { + let value = <#repr as sqlx::decode::Decode<'de, DB>>::decode(value)?; + + match value { + #(#arms)* + + _ => Err(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into())) + } + } + } + )) +} + +fn expand_derive_decode_strong_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let cattr = check_strong_enum_attributes(input, &variants)?; + + let ident = &input.ident; + let ident_s = ident.to_string(); + + let value_arms = variants.iter().map(|v| -> Arm { + let id = &v.ident; + let attributes = parse_child_attributes(&v.attrs).unwrap(); + + if let Some(rename) = attributes.rename { + parse_quote!(#rename => Ok(#ident :: #id),) + } else if let Some(pattern) = cattr.rename_all { + let name = rename_all(&*id.to_string(), pattern); + + parse_quote!(#name => Ok(#ident :: #id),) + } else { + let name = id.to_string(); + parse_quote!(#name => Ok(#ident :: #id),) + } + }); + + Ok(quote!( + impl<'de, DB: sqlx::Database> sqlx::decode::Decode<'de, DB> for #ident where &'de str: sqlx::decode::Decode<'de, DB> { + fn decode(value: >::RawValue) -> sqlx::Result { + let value = <&'de str as sqlx::decode::Decode<'de, DB>>::decode(value)?; + match value { + #(#value_arms)* + + _ => Err(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into())) + } + } + } + )) +} + +fn expand_derive_decode_struct( + input: &DeriveInput, + fields: &Punctuated, +) -> syn::Result { + check_struct_attributes(input, fields)?; + + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "postgres") { + let ident = &input.ident; + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!('de)); + + let predicates = &mut generics.make_where_clause().predicates; + + for field in fields { + let ty = &field.ty; + + predicates.push(parse_quote!(#ty: sqlx::decode::Decode<'de, sqlx::Postgres>)); + predicates.push(parse_quote!(#ty: sqlx::types::Type)); + } + + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + let reads = fields.iter().map(|field| -> Stmt { + let id = &field.ident; + let ty = &field.ty; + + parse_quote!( + let #id = decoder.decode::<#ty>()?; + ) + }); + + let names = fields.iter().map(|field| &field.ident); + + tts.extend(quote!( + impl #impl_generics sqlx::decode::Decode<'de, sqlx::Postgres> for #ident #ty_generics #where_clause { + fn decode(value: >::RawValue) -> sqlx::Result { + let mut decoder = sqlx::postgres::types::PgRecordDecoder::new(value)?; + + #(#reads)* + + Ok(#ident { + #(#names),* + }) + } + } + )); + } + + Ok(tts) +} diff --git a/sqlx-macros/src/derives/encode.rs b/sqlx-macros/src/derives/encode.rs new file mode 100644 index 00000000..e6b2c237 --- /dev/null +++ b/sqlx-macros/src/derives/encode.rs @@ -0,0 +1,223 @@ +use super::attributes::{ + check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, + check_weak_enum_attributes, parse_child_attributes, parse_container_attributes, +}; +use super::rename_all; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::token::Comma; +use syn::{ + parse_quote, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed, + FieldsUnnamed, Stmt, Variant, +}; + +pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result { + let args = parse_container_attributes(&input.attrs)?; + + match &input.data { + Data::Struct(DataStruct { + fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), + .. + }) if unnamed.len() == 1 => { + expand_derive_encode_transparent(&input, unnamed.first().unwrap()) + } + Data::Enum(DataEnum { variants, .. }) => match args.repr { + Some(_) => expand_derive_encode_weak_enum(input, variants), + None => expand_derive_encode_strong_enum(input, variants), + }, + Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { named, .. }), + .. + }) => expand_derive_encode_struct(input, named), + Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")), + Data::Struct(DataStruct { + fields: Fields::Unnamed(..), + .. + }) => Err(syn::Error::new_spanned( + input, + "structs with zero or more than one unnamed field are not supported", + )), + Data::Struct(DataStruct { + fields: Fields::Unit, + .. + }) => Err(syn::Error::new_spanned( + input, + "unit structs are not supported", + )), + } +} + +fn expand_derive_encode_transparent( + input: &DeriveInput, + field: &Field, +) -> syn::Result { + check_transparent_attributes(input, field)?; + + let ident = &input.ident; + let ty = &field.ty; + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::encode::Encode)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + Ok(quote!( + impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { + fn encode(&self, buf: &mut DB::RawBuffer) { + sqlx::encode::Encode::encode(&self.0, buf) + } + fn encode_nullable(&self, buf: &mut DB::RawBuffer) -> sqlx::encode::IsNull { + sqlx::encode::Encode::encode_nullable(&self.0, buf) + } + fn size_hint(&self) -> usize { + sqlx::encode::Encode::size_hint(&self.0) + } + } + )) +} + +fn expand_derive_encode_weak_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attr = check_weak_enum_attributes(input, &variants)?; + let repr = attr.repr.unwrap(); + + let ident = &input.ident; + + Ok(quote!( + impl sqlx::encode::Encode for #ident where #repr: sqlx::encode::Encode { + fn encode(&self, buf: &mut DB::RawBuffer) { + sqlx::encode::Encode::encode(&(*self as #repr), buf) + } + + fn encode_nullable(&self, buf: &mut DB::RawBuffer) -> sqlx::encode::IsNull { + sqlx::encode::Encode::encode_nullable(&(*self as #repr), buf) + } + + fn size_hint(&self) -> usize { + sqlx::encode::Encode::size_hint(&(*self as #repr)) + } + } + )) +} + +fn expand_derive_encode_strong_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let cattr = check_strong_enum_attributes(input, &variants)?; + + let ident = &input.ident; + + let mut value_arms = Vec::new(); + for v in variants { + let id = &v.ident; + let attributes = parse_child_attributes(&v.attrs)?; + + if let Some(rename) = attributes.rename { + value_arms.push(quote!(#ident :: #id => #rename,)); + } else if let Some(pattern) = cattr.rename_all { + let name = rename_all(&*id.to_string(), pattern); + + value_arms.push(quote!(#ident :: #id => #name,)); + } else { + let name = id.to_string(); + value_arms.push(quote!(#ident :: #id => #name,)); + } + } + + Ok(quote!( + impl sqlx::encode::Encode for #ident where str: sqlx::encode::Encode { + fn encode(&self, buf: &mut DB::RawBuffer) { + let val = match self { + #(#value_arms)* + }; + >::encode(val, buf) + } + + fn size_hint(&self) -> usize { + let val = match self { + #(#value_arms)* + }; + >::size_hint(val) + } + } + )) +} + +fn expand_derive_encode_struct( + input: &DeriveInput, + fields: &Punctuated, +) -> syn::Result { + check_struct_attributes(input, &fields)?; + + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "postgres") { + let ident = &input.ident; + let column_count = fields.len(); + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + let predicates = &mut generics.make_where_clause().predicates; + + for field in fields { + let ty = &field.ty; + + predicates.push(parse_quote!(#ty: sqlx::encode::Encode)); + predicates.push(parse_quote!(#ty: sqlx::types::Type)); + } + + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + let writes = fields.iter().map(|field| -> Stmt { + let id = &field.ident; + + parse_quote!( + // sqlx::postgres::encode_struct_field(buf, &self. #id); + encoder.encode(&self. #id); + ) + }); + + let sizes = fields.iter().map(|field| -> Expr { + let id = &field.ident; + let ty = &field.ty; + + parse_quote!( + <#ty as sqlx::encode::Encode>::size_hint(&self. #id) + ) + }); + + tts.extend(quote!( + impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { + fn encode(&self, buf: &mut std::vec::Vec) { + let mut encoder = sqlx::postgres::types::PgRecordEncoder::new(buf); + + #(#writes)* + + encoder.finish(); + } + + fn size_hint(&self) -> usize { + #column_count * (4 + 4) // oid (int) and length (int) for each column + + #(#sizes)+* // sum of the size hints for each column + } + } + )); + } + + Ok(tts) +} diff --git a/sqlx-macros/src/derives/mod.rs b/sqlx-macros/src/derives/mod.rs new file mode 100644 index 00000000..2d8208f2 --- /dev/null +++ b/sqlx-macros/src/derives/mod.rs @@ -0,0 +1,32 @@ +mod attributes; +mod decode; +mod encode; +mod r#type; + +pub(crate) use decode::expand_derive_decode; +pub(crate) use encode::expand_derive_encode; +pub(crate) use r#type::expand_derive_type; + +use self::attributes::RenameAll; +use std::iter::FromIterator; +use syn::DeriveInput; + +pub(crate) fn expand_derive_type_encode_decode( + input: &DeriveInput, +) -> syn::Result { + let encode_tts = expand_derive_encode(input)?; + let decode_tts = expand_derive_decode(input)?; + let type_tts = expand_derive_type(input)?; + + let combined = proc_macro2::TokenStream::from_iter( + encode_tts.into_iter().chain(decode_tts).chain(type_tts), + ); + + Ok(combined) +} + +pub(crate) fn rename_all(s: &str, pattern: RenameAll) -> String { + match pattern { + RenameAll::LowerCase => s.to_lowercase(), + } +} diff --git a/sqlx-macros/src/derives/type.rs b/sqlx-macros/src/derives/type.rs new file mode 100644 index 00000000..9fbec169 --- /dev/null +++ b/sqlx-macros/src/derives/type.rs @@ -0,0 +1,156 @@ +use super::attributes::{ + check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, + check_weak_enum_attributes, parse_container_attributes, +}; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::token::Comma; +use syn::{ + parse_quote, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed, + FieldsUnnamed, Variant, +}; + +pub fn expand_derive_type(input: &DeriveInput) -> syn::Result { + let attrs = parse_container_attributes(&input.attrs)?; + match &input.data { + Data::Struct(DataStruct { + fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), + .. + }) if unnamed.len() == 1 => { + expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap()) + } + Data::Enum(DataEnum { variants, .. }) => match attrs.repr { + Some(_) => expand_derive_has_sql_type_weak_enum(input, variants), + None => expand_derive_has_sql_type_strong_enum(input, variants), + }, + Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { named, .. }), + .. + }) => expand_derive_has_sql_type_struct(input, named), + Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")), + Data::Struct(DataStruct { + fields: Fields::Unnamed(..), + .. + }) => Err(syn::Error::new_spanned( + input, + "structs with zero or more than one unnamed field are not supported", + )), + Data::Struct(DataStruct { + fields: Fields::Unit, + .. + }) => Err(syn::Error::new_spanned( + input, + "unit structs are not supported", + )), + } +} + +fn expand_derive_has_sql_type_transparent( + input: &DeriveInput, + field: &Field, +) -> syn::Result { + check_transparent_attributes(input, field)?; + + let ident = &input.ident; + let ty = &field.ty; + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for clause + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::types::Type)); + + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + Ok(quote!( + impl #impl_generics sqlx::types::Type< DB > for #ident #ty_generics #where_clause { + fn type_info() -> DB::TypeInfo { + <#ty as sqlx::Type>::type_info() + } + } + )) +} + +fn expand_derive_has_sql_type_weak_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attr = check_weak_enum_attributes(input, variants)?; + let repr = attr.repr.unwrap(); + let ident = &input.ident; + + Ok(quote!( + impl sqlx::Type for #ident + where + #repr: sqlx::Type, + { + fn type_info() -> DB::TypeInfo { + <#repr as sqlx::Type>::type_info() + } + } + )) +} + +fn expand_derive_has_sql_type_strong_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = check_strong_enum_attributes(input, variants)?; + + let ident = &input.ident; + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl sqlx::Type< sqlx::MySql > for #ident { + fn type_info() -> sqlx::mysql::MySqlTypeInfo { + // This is really fine, MySQL is loosely typed and + // we don't nede to be specific here + >::type_info() + } + } + )); + } + + if cfg!(feature = "postgres") { + let oid = attributes.postgres_oid.unwrap(); + tts.extend(quote!( + impl sqlx::Type< sqlx::Postgres > for #ident { + fn type_info() -> sqlx::postgres::PgTypeInfo { + sqlx::postgres::PgTypeInfo::with_oid(#oid) + } + } + )); + } + + Ok(tts) +} + +fn expand_derive_has_sql_type_struct( + input: &DeriveInput, + fields: &Punctuated, +) -> syn::Result { + let attributes = check_struct_attributes(input, fields)?; + + let ident = &input.ident; + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "postgres") { + let oid = attributes.postgres_oid.unwrap(); + tts.extend(quote!( + impl sqlx::types::Type< sqlx::Postgres > for #ident { + fn type_info() -> sqlx::postgres::PgTypeInfo { + sqlx::postgres::PgTypeInfo::with_oid(#oid) + } + } + )); + } + + Ok(tts) +} diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs index 0bed307f..d0a19995 100644 --- a/sqlx-macros/src/lib.rs +++ b/sqlx-macros/src/lib.rs @@ -151,19 +151,28 @@ pub fn query_file_as(input: TokenStream) -> TokenStream { async_macro!(db, input: QueryAsMacroInput => expand_query_file_as(input, db)) } -#[proc_macro_derive(Encode)] +#[proc_macro_derive(Encode, attributes(sqlx))] pub fn derive_encode(tokenstream: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); - match derives::expand_derive_encode(input) { + match derives::expand_derive_encode(&input) { Ok(ts) => ts.into(), Err(e) => e.to_compile_error().into(), } } -#[proc_macro_derive(Decode)] +#[proc_macro_derive(Decode, attributes(sqlx))] pub fn derive_decode(tokenstream: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); - match derives::expand_derive_decode(input) { + match derives::expand_derive_decode(&input) { + Ok(ts) => ts.into(), + Err(e) => e.to_compile_error().into(), + } +} + +#[proc_macro_derive(Type, attributes(sqlx))] +pub fn derive_type(tokenstream: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); + match derives::expand_derive_type_encode_decode(&input) { Ok(ts) => ts.into(), Err(e) => e.to_compile_error().into(), } diff --git a/src/lib.rs b/src/lib.rs index d1dc6aaa..3342b09b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -40,6 +40,9 @@ pub use sqlx_core::sqlite::{self, Sqlite, SqliteConnection, SqlitePool}; #[doc(hidden)] pub extern crate sqlx_macros; +#[cfg(feature = "macros")] +pub use sqlx_macros::Type; + #[cfg(feature = "macros")] mod macros; diff --git a/src/types.rs b/src/types.rs new file mode 100644 index 00000000..6eb3cab6 --- /dev/null +++ b/src/types.rs @@ -0,0 +1,6 @@ +//! Traits linking Rust types to SQL types. + +pub use sqlx_core::types::*; + +#[cfg(feature = "macros")] +pub use sqlx_macros::HasSqlType; diff --git a/tests/derives.rs b/tests/derives.rs deleted file mode 100644 index d5120031..00000000 --- a/tests/derives.rs +++ /dev/null @@ -1,77 +0,0 @@ -use sqlx::decode::Decode; -use sqlx::encode::Encode; - -#[derive(PartialEq, Debug, Encode, Decode)] -struct Foo(i32); - -#[test] -#[cfg(feature = "postgres")] -fn encode_with_postgres() { - use sqlx_core::postgres::Postgres; - - let example = Foo(0x1122_3344); - - let mut encoded = Vec::new(); - let mut encoded_orig = Vec::new(); - - Encode::::encode(&example, &mut encoded); - Encode::::encode(&example.0, &mut encoded_orig); - - assert_eq!(encoded, encoded_orig); -} - -#[test] -#[cfg(feature = "mysql")] -fn encode_with_mysql() { - use sqlx_core::mysql::MySql; - - let example = Foo(0x1122_3344); - - let mut encoded = Vec::new(); - let mut encoded_orig = Vec::new(); - - Encode::::encode(&example, &mut encoded); - Encode::::encode(&example.0, &mut encoded_orig); - - assert_eq!(encoded, encoded_orig); -} - -#[test] -#[cfg(feature = "mysql")] -fn decode_mysql() { - decode_with_db(); -} - -#[test] -#[cfg(feature = "postgres")] -fn decode_postgres() { - decode_with_db(); -} - -#[cfg(feature = "postgres")] -fn decode_with_db() -where - Foo: for<'de> Decode<'de, sqlx::Postgres> + Encode, -{ - let example = Foo(0x1122_3344); - - let mut encoded = Vec::new(); - Encode::::encode(&example, &mut encoded); - - let decoded = Foo::decode(Some(sqlx::postgres::PgValue::Binary(&encoded))).unwrap(); - assert_eq!(example, decoded); -} - -#[cfg(feature = "mysql")] -fn decode_with_db() -where - Foo: for<'de> Decode<'de, sqlx::MySql> + Encode, -{ - let example = Foo(0x1122_3344); - - let mut encoded = Vec::new(); - Encode::::encode(&example, &mut encoded); - - let decoded = Foo::decode(Some(sqlx::mysql::MySqlValue::Binary(&encoded))).unwrap(); - assert_eq!(example, decoded); -} diff --git a/tests/mysql-derives.rs b/tests/mysql-derives.rs new file mode 100644 index 00000000..0e875e8c --- /dev/null +++ b/tests/mysql-derives.rs @@ -0,0 +1,43 @@ +use sqlx::MySql; +use sqlx_test::test_type; +use std::fmt::Debug; + +// Transparent types are rust-side wrappers over DB types +#[derive(PartialEq, Debug, sqlx::Type)] +#[sqlx(transparent)] +struct Transparent(i32); + +// "Weak" enums map to an integer type indicated by #[repr] +#[derive(PartialEq, Copy, Clone, Debug, sqlx::Type)] +#[repr(i32)] +enum Weak { + One = 0, + Two = 2, + Three = 4, +} + +// "Strong" enums can map to TEXT or a custom enum +#[derive(PartialEq, Debug, sqlx::Type)] +#[sqlx(rename_all = "lowercase")] +enum Color { + Red, + Green, + Blue, +} + +test_type!(transparent( + MySql, + Transparent, + "0" == Transparent(0), + "23523" == Transparent(23523) +)); + +test_type!(weak_enum( + MySql, + Weak, + "0" == Weak::One, + "2" == Weak::Two, + "4" == Weak::Three +)); + +test_type!(strong_color_enum(MySql, Color, "'green'" == Color::Green)); diff --git a/tests/postgres-derives.rs b/tests/postgres-derives.rs new file mode 100644 index 00000000..11689b1b --- /dev/null +++ b/tests/postgres-derives.rs @@ -0,0 +1,86 @@ +use sqlx::Postgres; +use sqlx_test::test_type; +use std::fmt::Debug; + +// Transparent types are rust-side wrappers over DB types +#[derive(PartialEq, Debug, sqlx::Type)] +#[sqlx(transparent)] +struct Transparent(i32); + +// "Weak" enums map to an integer type indicated by #[repr] +#[derive(PartialEq, Copy, Clone, Debug, sqlx::Type)] +#[repr(i32)] +enum Weak { + One = 0, + Two = 2, + Three = 4, +} + +// "Strong" enums can map to TEXT (25) or a custom enum type +#[derive(PartialEq, Debug, sqlx::Type)] +#[sqlx(postgres(oid = 25))] +#[sqlx(rename_all = "lowercase")] +enum Strong { + One, + Two, + + #[sqlx(rename = "four")] + Three, +} + +// TODO: Figure out a good solution for custom type testing +// Records must map to a custom type +// Note that all types are types in Postgres +// #[derive(PartialEq, Debug, sqlx::Type)] +// #[sqlx(postgres(oid = 12184))] +// struct PgConfig { +// name: String, +// setting: Option, +// } + +test_type!(transparent( + Postgres, + Transparent, + "0" == Transparent(0), + "23523" == Transparent(23523) +)); + +test_type!(weak_enum( + Postgres, + Weak, + "0::int4" == Weak::One, + "2::int4" == Weak::Two, + "4::int4" == Weak::Three +)); + +test_type!(strong_enum( + Postgres, + Strong, + "'one'::text" == Strong::One, + "'two'::text" == Strong::Two, + "'four'::text" == Strong::Three +)); + +// TODO: Figure out a good solution for custom type testing +// test_type!(record_pg_config( +// Postgres, +// PgConfig, +// // (CC,gcc) +// "(SELECT ROW('CC', 'gcc')::pg_config)" +// == PgConfig { +// name: "CC".to_owned(), +// setting: Some("gcc".to_owned()), +// }, +// // (CC,) +// "(SELECT '(\"CC\",)'::pg_config)" +// == PgConfig { +// name: "CC".to_owned(), +// setting: None, +// }, +// // (CC,"") +// "(SELECT '(\"CC\",\"\")'::pg_config)" +// == PgConfig { +// name: "CC".to_owned(), +// setting: Some("".to_owned()), +// } +// )); diff --git a/tests/postgres-types.rs b/tests/postgres-types.rs index 991a5bf6..46b52427 100644 --- a/tests/postgres-types.rs +++ b/tests/postgres-types.rs @@ -1,5 +1,11 @@ -use sqlx::Postgres; -use sqlx_test::test_type; +use sqlx::decode::Decode; +use sqlx::encode::Encode; +use sqlx::postgres::types::PgRecordEncoder; +use sqlx::postgres::{PgQueryAs, PgTypeInfo, PgValue}; +use sqlx::{Cursor, Executor, Postgres, Row, Type}; +use sqlx_core::postgres::types::PgRecordDecoder; +use sqlx_test::{new, test_type}; +use std::sync::atomic::{AtomicU32, Ordering}; test_type!(null( Postgres, @@ -87,3 +93,188 @@ mod chrono { ) )); } + +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn test_prepared_anonymous_record() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // Tuple of no elements is not possible + // Tuple of 1 element requires a concrete type + // Tuple with a NULL requires a concrete type + + // Tuple of 2 elements + let rec: ((bool, i32),) = sqlx::query_as("SELECT (true, 23512)") + .fetch_one(&mut conn) + .await?; + + assert_eq!((rec.0).0, true); + assert_eq!((rec.0).1, 23512); + + // Tuple with an empty string + let rec: ((bool, String),) = sqlx::query_as("SELECT (true,'')") + .fetch_one(&mut conn) + .await?; + + assert_eq!((rec.0).1, ""); + + // Tuple with a string with an interior comma + let rec: ((bool, String),) = sqlx::query_as("SELECT (true,'Hello, World!')") + .fetch_one(&mut conn) + .await?; + + assert_eq!((rec.0).1, "Hello, World!"); + + // Tuple with a string with an emoji + let rec: ((bool, String),) = sqlx::query_as("SELECT (true,'Hello, 🐕!')") + .fetch_one(&mut conn) + .await?; + + assert_eq!((rec.0).1, "Hello, 🐕!"); + + Ok(()) +} + +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn test_unprepared_anonymous_record() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // Tuple of no elements is not possible + // Tuple of 1 element requires a concrete type + // Tuple with a NULL requires a concrete type + + // Tuple of 2 elements + let mut cursor = conn.fetch("SELECT (true, 23512)"); + let row = cursor.next().await?.unwrap(); + let rec: (bool, i32) = row.get(0); + + assert_eq!(rec.0, true); + assert_eq!(rec.1, 23512); + + // Tuple with an empty string + let mut cursor = conn.fetch("SELECT (true, '')"); + let row = cursor.next().await?.unwrap(); + let rec: (bool, String) = row.get(0); + + assert_eq!(rec.1, ""); + + // Tuple with a string with an interior comma + let mut cursor = conn.fetch("SELECT (true, 'Hello, World!')"); + let row = cursor.next().await?.unwrap(); + let rec: (bool, String) = row.get(0); + + assert_eq!(rec.1, "Hello, World!"); + + // Tuple with a string with an emoji + let mut cursor = conn.fetch("SELECT (true, 'Hello, 🐕!')"); + let row = cursor.next().await?.unwrap(); + let rec: (bool, String) = row.get(0); + + assert_eq!(rec.1, "Hello, 🐕!"); + + Ok(()) +} + +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn test_prepared_structs() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // + // Setup custom types if needed + // + + static OID_RECORD_EMPTY: AtomicU32 = AtomicU32::new(0); + static OID_RECORD_1: AtomicU32 = AtomicU32::new(0); + + conn.execute( + r#" +DO $$ BEGIN + CREATE TYPE _sqlx_record_empty AS (); + CREATE TYPE _sqlx_record_1 AS (_1 int8); +EXCEPTION + WHEN duplicate_object THEN null; +END $$; + "#, + ) + .await?; + + let type_ids: Vec<(i32,)> = sqlx::query_as( + "SELECT oid::int4 FROM pg_type WHERE typname IN ('_sqlx_record_empty', '_sqlx_record_1')", + ) + .fetch_all(&mut conn) + .await?; + + OID_RECORD_EMPTY.store(type_ids[0].0 as u32, Ordering::SeqCst); + OID_RECORD_1.store(type_ids[1].0 as u32, Ordering::SeqCst); + + // + // Record of no elements + // + + struct RecordEmpty {} + + impl Type for RecordEmpty { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_oid(OID_RECORD_EMPTY.load(Ordering::SeqCst)) + } + } + + impl Encode for RecordEmpty { + fn encode(&self, buf: &mut Vec) { + PgRecordEncoder::new(buf).finish(); + } + } + + impl<'de> Decode<'de, Postgres> for RecordEmpty { + fn decode(_value: Option>) -> sqlx::Result { + Ok(RecordEmpty {}) + } + } + + let _: (RecordEmpty, RecordEmpty) = sqlx::query_as("SELECT '()'::_sqlx_record_empty, $1") + .bind(RecordEmpty {}) + .fetch_one(&mut conn) + .await?; + + // + // Record of one element + // + + #[derive(Debug, PartialEq)] + struct Record1 { + _1: i64, + } + + impl Type for Record1 { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_oid(OID_RECORD_1.load(Ordering::SeqCst)) + } + } + + impl Encode for Record1 { + fn encode(&self, buf: &mut Vec) { + PgRecordEncoder::new(buf).encode(self._1).finish(); + } + } + + impl<'de> Decode<'de, Postgres> for Record1 { + fn decode(value: Option>) -> sqlx::Result { + let mut decoder = PgRecordDecoder::new(value)?; + + let _1 = decoder.decode()?; + + Ok(Record1 { _1 }) + } + } + + let rec: (Record1, Record1) = sqlx::query_as("SELECT '(324235)'::_sqlx_record_1, $1") + .bind(Record1 { _1: 324235 }) + .fetch_one(&mut conn) + .await?; + + assert_eq!(rec.0, rec.1); + + Ok(()) +} diff --git a/tests/postgres.rs b/tests/postgres.rs index 2eb1c922..9f11d87f 100644 --- a/tests/postgres.rs +++ b/tests/postgres.rs @@ -1,12 +1,13 @@ use futures::TryStreamExt; use sqlx::postgres::{PgPool, PgQueryAs, PgRow}; -use sqlx::{postgres::PgConnection, Connect, Connection, Executor, Row}; +use sqlx::{Connection, Executor, Postgres, Row}; +use sqlx_test::new; use std::time::Duration; #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn it_connects() -> anyhow::Result<()> { - let mut conn = connect().await?; + let mut conn = new::().await?; let value = sqlx::query("select 1 + 1") .try_map(|row: PgRow| row.try_get::(0)) @@ -21,7 +22,7 @@ async fn it_connects() -> anyhow::Result<()> { #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn it_executes() -> anyhow::Result<()> { - let mut conn = connect().await?; + let mut conn = new::().await?; let _ = conn .execute( @@ -55,7 +56,7 @@ CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY); #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn it_can_return_interleaved_nulls_issue_104() -> anyhow::Result<()> { - let mut conn = connect().await?; + let mut conn = new::().await?; let tuple = sqlx::query("SELECT NULL::INT, 10::INT, NULL, 20::INT, NULL, 40::INT, NULL, 80::INT") @@ -89,7 +90,7 @@ async fn it_can_return_interleaved_nulls_issue_104() -> anyhow::Result<()> { #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn it_can_work_with_transactions() -> anyhow::Result<()> { - let mut conn = connect().await?; + let mut conn = new::().await?; conn.execute("CREATE TABLE IF NOT EXISTS _sqlx_users_1922 (id INTEGER PRIMARY KEY)") .await?; @@ -141,7 +142,7 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> { .await?; } - conn = connect().await?; + conn = new::().await?; let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_1922") .fetch_one(&mut conn) @@ -210,7 +211,7 @@ async fn pool_smoke_test() -> anyhow::Result<()> { #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn test_describe() -> anyhow::Result<()> { - let mut conn = connect().await?; + let mut conn = new::().await?; let _ = conn .execute( @@ -239,10 +240,3 @@ async fn test_describe() -> anyhow::Result<()> { Ok(()) } - -async fn connect() -> anyhow::Result { - let _ = dotenv::dotenv(); - let _ = env_logger::try_init(); - - Ok(PgConnection::connect(dotenv::var("DATABASE_URL")?).await?) -}