diff --git a/Cargo.toml b/Cargo.toml index e3cd32fa9..6e3443c6a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -92,6 +92,10 @@ required-features = [ "mysql" ] name = "mysql-raw" required-features = [ "mysql" ] +[[test]] +name = "mysql-derives" +required-features = [ "mysql", "macros" ] + [[test]] name = "postgres" required-features = [ "postgres" ] @@ -104,6 +108,10 @@ required-features = [ "postgres" ] name = "postgres-types" required-features = [ "postgres" ] +[[test]] +name = "postgres-derives" +required-features = [ "postgres", "macros" ] + [[test]] name = "mysql-types" required-features = [ "mysql" ] diff --git a/sqlx-macros/src/derives/attributes.rs b/sqlx-macros/src/derives/attributes.rs index 72df69910..a75c82a82 100644 --- a/sqlx-macros/src/derives/attributes.rs +++ b/sqlx-macros/src/derives/attributes.rs @@ -11,33 +11,42 @@ macro_rules! assert_attribute { }; } -pub struct SqlxAttributes { +macro_rules! fail { + ($t:expr, $m:expr) => { + return Err(syn::Error::new_spanned($t, $m)); + }; +} + +macro_rules! try_set { + ($i:ident, $v:expr, $t:expr) => { + match $i { + None => $i = Some($v), + Some(_) => fail!($t, "duplicate attribute"), + } + }; +} + +#[derive(Copy, Clone)] +pub enum RenameAll { + LowerCase, +} + +pub struct SqlxContainerAttributes { pub transparent: bool, pub postgres_oid: Option, + pub rename_all: Option, pub repr: Option, +} + +pub struct SqlxChildAttributes { pub rename: Option, } -pub fn parse_attributes(input: &[Attribute]) -> syn::Result { +pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result { let mut transparent = None; let mut postgres_oid = None; let mut repr = None; - let mut rename = None; - - macro_rules! fail { - ($t:expr, $m:expr) => { - return Err(syn::Error::new_spanned($t, $m)); - }; - } - - macro_rules! try_set { - ($i:ident, $v:expr, $t:expr) => { - match $i { - None => $i = Some($v), - Some(_) => fail!($t, "duplicate attribute"), - } - }; - } + let mut rename_all = None; for attr in input { let meta = attr @@ -51,11 +60,21 @@ pub fn parse_attributes(input: &[Attribute]) -> syn::Result { Meta::Path(p) if p.is_ident("transparent") => { try_set!(transparent, true, value) } + Meta::NameValue(MetaNameValue { path, lit: Lit::Str(val), .. - }) if path.is_ident("rename") => try_set!(rename, val.value(), value), + }) if path.is_ident("rename_all") => { + let val = match &*val.value() { + "lowercase" => RenameAll::LowerCase, + + _ => fail!(meta, "unexpected value for rename_all"), + }; + + try_set!(rename_all, val, value) + }, + Meta::List(list) if list.path.is_ident("postgres") => { for value in list.nested.iter() { match value { @@ -92,85 +111,93 @@ pub fn parse_attributes(input: &[Attribute]) -> syn::Result { } } - Ok(SqlxAttributes { + Ok(SqlxContainerAttributes { transparent: transparent.unwrap_or(false), postgres_oid, repr, + rename_all, + }) +} + +pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result { + let mut rename = None; + + for attr in input { + let meta = attr + .parse_meta() + .map_err(|e| syn::Error::new_spanned(attr, e))?; + + match meta { + Meta::List(list) if list.path.is_ident("sqlx") => { + for value in list.nested.iter() { + match value { + NestedMeta::Meta(meta) => match meta { + Meta::NameValue(MetaNameValue { + path, + lit: Lit::Str(val), + .. + }) if path.is_ident("rename") => try_set!(rename, val.value(), value), + + u => fail!(u, "unexpected attribute"), + }, + u => fail!(u, "unexpected attribute"), + } + } + } + _ => {} + } + } + + Ok(SqlxChildAttributes { rename, }) } pub fn check_transparent_attributes(input: &DeriveInput, field: &Field) -> syn::Result<()> { - let attributes = parse_attributes(&input.attrs)?; + let attributes = parse_container_attributes(&input.attrs)?; + assert_attribute!( attributes.transparent, "expected #[sqlx(transparent)]", input ); + #[cfg(feature = "postgres")] assert_attribute!( attributes.postgres_oid.is_none(), "unexpected #[sqlx(postgres(oid = ..))]", input ); + assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", + attributes.rename_all.is_none(), + "unexpected #[sqlx(rename_all = ..)]", field ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); - let attributes = parse_attributes(&field.attrs)?; - assert_attribute!( - !attributes.transparent, - "unexpected #[sqlx(transparent)]", - field - ); - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_none(), - "unexpected #[sqlx(postgres(oid = ..))]", - field - ); + + let attributes = parse_child_attributes(&field.attrs)?; + assert_attribute!( attributes.rename.is_none(), "unexpected #[sqlx(rename = ..)]", field ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field); + Ok(()) } pub fn check_enum_attributes<'a>( input: &'a DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let attributes = parse_attributes(&input.attrs)?; +) -> syn::Result { + let attributes = parse_container_attributes(&input.attrs)?; + assert_attribute!( !attributes.transparent, "unexpected #[sqlx(transparent)]", input ); - assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", - input - ); - - for variant in variants { - let attributes = parse_attributes(&variant.attrs)?; - assert_attribute!( - !attributes.transparent, - "unexpected #[sqlx(transparent)]", - variant - ); - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_none(), - "unexpected #[sqlx(postgres(oid = ..))]", - variant - ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", variant); - } Ok(attributes) } @@ -178,83 +205,90 @@ pub fn check_enum_attributes<'a>( pub fn check_weak_enum_attributes( input: &DeriveInput, variants: &Punctuated, -) -> syn::Result { - let attributes = check_enum_attributes(input, variants)?; +) -> syn::Result { + let attributes = check_enum_attributes(input)?; + #[cfg(feature = "postgres")] assert_attribute!( attributes.postgres_oid.is_none(), "unexpected #[sqlx(postgres(oid = ..))]", input ); + assert_attribute!(attributes.repr.is_some(), "expected #[repr(..)]", input); + + assert_attribute!( + attributes.rename_all.is_none(), + "unexpected #[sqlx(c = ..)]", + input + ); + for variant in variants { - let attributes = parse_attributes(&variant.attrs)?; + let attributes = parse_child_attributes(&variant.attrs)?; + assert_attribute!( attributes.rename.is_none(), "unexpected #[sqlx(rename = ..)]", variant ); } - Ok(attributes.repr.unwrap()) + + Ok(attributes) } pub fn check_strong_enum_attributes( input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let attributes = check_enum_attributes(input, variants)?; + _variants: &Punctuated, +) -> syn::Result { + let attributes = check_enum_attributes(input)?; + #[cfg(feature = "postgres")] assert_attribute!( attributes.postgres_oid.is_some(), "expected #[sqlx(postgres(oid = ..))]", input ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); + Ok(attributes) } pub fn check_struct_attributes<'a>( input: &'a DeriveInput, fields: &Punctuated, -) -> syn::Result { - let attributes = parse_attributes(&input.attrs)?; +) -> syn::Result { + let attributes = parse_container_attributes(&input.attrs)?; + assert_attribute!( !attributes.transparent, "unexpected #[sqlx(transparent)]", input ); + #[cfg(feature = "postgres")] assert_attribute!( attributes.postgres_oid.is_some(), "expected #[sqlx(postgres(oid = ..))]", input ); + assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", + attributes.rename_all.is_none(), + "unexpected #[sqlx(rename_all = ..)]", input ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); for field in fields { - let attributes = parse_attributes(&field.attrs)?; - assert_attribute!( - !attributes.transparent, - "unexpected #[sqlx(transparent)]", - field - ); - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_none(), - "unexpected #[sqlx(postgres(oid = ..))]", - field - ); + let attributes = parse_child_attributes(&field.attrs)?; + assert_attribute!( attributes.rename.is_none(), "unexpected #[sqlx(rename = ..)]", field ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field); } Ok(attributes) diff --git a/sqlx-macros/src/derives/decode.rs b/sqlx-macros/src/derives/decode.rs index b6b818eea..9dae6fa33 100644 --- a/sqlx-macros/src/derives/decode.rs +++ b/sqlx-macros/src/derives/decode.rs @@ -1,7 +1,9 @@ use super::attributes::{ check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, - check_weak_enum_attributes, parse_attributes, + check_weak_enum_attributes, parse_container_attributes, + parse_child_attributes, }; +use super::rename_all; use quote::quote; use syn::punctuated::Punctuated; use syn::token::Comma; @@ -11,7 +13,7 @@ use syn::{ }; pub fn expand_derive_decode(input: &DeriveInput) -> syn::Result { - let attrs = parse_attributes(&input.attrs)?; + let attrs = parse_container_attributes(&input.attrs)?; match &input.data { Data::Struct(DataStruct { fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), @@ -83,24 +85,29 @@ fn expand_derive_decode_weak_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { - let repr = check_weak_enum_attributes(input, &variants)?; + let attr = check_weak_enum_attributes(input, &variants)?; + let repr = attr.repr.unwrap(); let ident = &input.ident; + let ident_s = ident.to_string(); + let arms = variants .iter() .map(|v| { let id = &v.ident; - parse_quote!(_ if (#ident :: #id as #repr) == val => Ok(#ident :: #id),) + parse_quote!(_ if (#ident :: #id as #repr) == value => Ok(#ident :: #id),) }) .collect::>(); Ok(quote!( - impl sqlx::decode::Decode for #ident where #repr: sqlx::decode::Decode { - fn decode(raw: &[u8]) -> std::result::Result { - let val = <#repr as sqlx::decode::Decode>::decode(raw)?; - match val { + impl<'de, DB: sqlx::Database> sqlx::decode::Decode<'de, DB> for #ident where #repr: sqlx::decode::Decode<'de, DB> { + fn decode(value: >::RawValue) -> sqlx::Result { + let value = <#repr as sqlx::decode::Decode<'de, DB>>::decode(value)?; + + match value { #(#arms)* - _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) + + _ => Err(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into())) } } } @@ -111,29 +118,35 @@ fn expand_derive_decode_strong_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { - check_strong_enum_attributes(input, &variants)?; + let cattr = check_strong_enum_attributes(input, &variants)?; let ident = &input.ident; + let ident_s = ident.to_string(); let value_arms = variants.iter().map(|v| -> Arm { let id = &v.ident; - let attributes = parse_attributes(&v.attrs).unwrap(); + let attributes = parse_child_attributes(&v.attrs).unwrap(); + if let Some(rename) = attributes.rename { parse_quote!(#rename => Ok(#ident :: #id),) + } else if let Some(pattern) = cattr.rename_all { + let name = rename_all(&*id.to_string(), pattern); + + parse_quote!(#name => Ok(#ident :: #id),) } else { let name = id.to_string(); parse_quote!(#name => Ok(#ident :: #id),) } }); - // TODO: prevent heap allocation Ok(quote!( - impl sqlx::decode::Decode for #ident where std::string::String: sqlx::decode::Decode { - fn decode(buf: &[u8]) -> std::result::Result { - let val = >::decode(buf)?; - match val.as_str() { + impl<'de, DB: sqlx::Database> sqlx::decode::Decode<'de, DB> for #ident where &'de str: sqlx::decode::Decode<'de, DB> { + fn decode(value: >::RawValue) -> sqlx::Result { + let value = <&'de str as sqlx::decode::Decode<'de, DB>>::decode(value)?; + match value { #(#value_arms)* - _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) + + _ => Err(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into())) } } } @@ -151,57 +164,50 @@ fn expand_derive_decode_struct( if cfg!(feature = "postgres") { let ident = &input.ident; - let column_count = fields.len(); - // extract type generics let generics = &input.generics; let (_, ty_generics, _) = generics.split_for_impl(); // add db type for impl generics & where clause let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!('de)); + let predicates = &mut generics.make_where_clause().predicates; + for field in fields { let ty = &field.ty; - predicates.push(parse_quote!(#ty: sqlx::decode::Decode)); + + predicates.push(parse_quote!(#ty: sqlx::decode::Decode<'de, sqlx::Postgres>)); predicates.push(parse_quote!(#ty: sqlx::types::Type)); } + let (impl_generics, _, where_clause) = generics.split_for_impl(); let reads = fields.iter().map(|field| -> Stmt { let id = &field.ident; let ty = &field.ty; + parse_quote!( - let #id = sqlx::postgres::decode_struct_field::<#ty>(&mut buf)?; + let #id = decoder.decode::<#ty>()?; ) }); let names = fields.iter().map(|field| &field.ident); tts.extend(quote!( - impl #impl_generics sqlx::decode::Decode for #ident#ty_generics #where_clause { - fn decode(buf: &[u8]) -> std::result::Result { - if buf.len() < 4 { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); + impl #impl_generics sqlx::decode::Decode<'de, sqlx::Postgres> for #ident #ty_generics #where_clause { + fn decode(value: >::RawValue) -> sqlx::Result { + let mut decoder = sqlx::postgres::types::PgRecordDecoder::new(value)?; + + #(#reads)* + + Ok(#ident { + #(#names),* + }) } - - let column_count = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[..4]).unwrap()) as usize; - if column_count != #column_count { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid column count"))); - } - let mut buf = &buf[4..]; - - #(#reads)* - - if !buf.is_empty() { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new(format!("Too much data sent ({} bytes left)", buf.len())))); - } - - Ok(#ident { - #(#names),* - }) } - } - )) + )); } + Ok(tts) } diff --git a/sqlx-macros/src/derives/encode.rs b/sqlx-macros/src/derives/encode.rs index 56ad9fb77..2de0064da 100644 --- a/sqlx-macros/src/derives/encode.rs +++ b/sqlx-macros/src/derives/encode.rs @@ -1,7 +1,8 @@ use super::attributes::{ check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, - check_weak_enum_attributes, parse_attributes, + check_weak_enum_attributes, parse_container_attributes, parse_child_attributes, }; +use super::rename_all; use quote::quote; use syn::punctuated::Punctuated; use syn::token::Comma; @@ -11,7 +12,7 @@ use syn::{ }; pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result { - let args = parse_attributes(&input.attrs)?; + let args = parse_container_attributes(&input.attrs)?; match &input.data { Data::Struct(DataStruct { @@ -87,18 +88,21 @@ fn expand_derive_encode_weak_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { - let repr = check_weak_enum_attributes(input, &variants)?; + let attr = check_weak_enum_attributes(input, &variants)?; + let repr = attr.repr.unwrap(); let ident = &input.ident; Ok(quote!( impl sqlx::encode::Encode for #ident where #repr: sqlx::encode::Encode { - fn encode(&self, buf: &mut std::vec::Vec) { + fn encode(&self, buf: &mut DB::RawBuffer) { sqlx::encode::Encode::encode(&(*self as #repr), buf) } - fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { + + fn encode_nullable(&self, buf: &mut DB::RawBuffer) -> sqlx::encode::IsNull { sqlx::encode::Encode::encode_nullable(&(*self as #repr), buf) } + fn size_hint(&self) -> usize { sqlx::encode::Encode::size_hint(&(*self as #repr)) } @@ -110,16 +114,21 @@ fn expand_derive_encode_strong_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { - check_strong_enum_attributes(input, &variants)?; + let cattr = check_strong_enum_attributes(input, &variants)?; let ident = &input.ident; let mut value_arms = Vec::new(); for v in variants { let id = &v.ident; - let attributes = parse_attributes(&v.attrs)?; + let attributes = parse_child_attributes(&v.attrs)?; + if let Some(rename) = attributes.rename { value_arms.push(quote!(#ident :: #id => #rename,)); + } else if let Some(pattern) = cattr.rename_all { + let name = rename_all(&*id.to_string(), pattern); + + value_arms.push(quote!(#ident :: #id => #name,)); } else { let name = id.to_string(); value_arms.push(quote!(#ident :: #id => #name,)); @@ -128,12 +137,13 @@ fn expand_derive_encode_strong_enum( Ok(quote!( impl sqlx::encode::Encode for #ident where str: sqlx::encode::Encode { - fn encode(&self, buf: &mut std::vec::Vec) { + fn encode(&self, buf: &mut DB::RawBuffer) { let val = match self { #(#value_arms)* }; >::encode(val, buf) } + fn size_hint(&self) -> usize { let val = match self { #(#value_arms)* @@ -154,7 +164,6 @@ fn expand_derive_encode_struct( if cfg!(feature = "postgres") { let ident = &input.ident; - let column_count = fields.len(); // extract type generics @@ -164,23 +173,29 @@ fn expand_derive_encode_struct( // add db type for impl generics & where clause let mut generics = generics.clone(); let predicates = &mut generics.make_where_clause().predicates; + for field in fields { let ty = &field.ty; + predicates.push(parse_quote!(#ty: sqlx::encode::Encode)); - predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); + predicates.push(parse_quote!(#ty: sqlx::types::Type)); } + let (impl_generics, _, where_clause) = generics.split_for_impl(); let writes = fields.iter().map(|field| -> Stmt { let id = &field.ident; + parse_quote!( - sqlx::postgres::encode_struct_field(buf, &self. #id); + // sqlx::postgres::encode_struct_field(buf, &self. #id); + encoder.encode(&self. #id); ) }); let sizes = fields.iter().map(|field| -> Expr { let id = &field.ident; let ty = &field.ty; + parse_quote!( <#ty as sqlx::encode::Encode>::size_hint(&self. #id) ) @@ -189,13 +204,16 @@ fn expand_derive_encode_struct( tts.extend(quote!( impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { fn encode(&self, buf: &mut std::vec::Vec) { - buf.extend(&(#column_count as u32).to_be_bytes()); + let mut encoder = sqlx::postgres::types::PgRecordEncoder::new(buf); + #(#writes)* + + encoder.finish(); } + fn size_hint(&self) -> usize { - 4 // oid - + #column_count * (4 + 4) // oid (int) and length (int) for each column - + #(#sizes)+* // sum of the size hints for each column + #column_count * (4 + 4) // oid (int) and length (int) for each column + + #(#sizes)+* // sum of the size hints for each column } } )); diff --git a/sqlx-macros/src/derives/mod.rs b/sqlx-macros/src/derives/mod.rs index 4e36533dd..888b737ea 100644 --- a/sqlx-macros/src/derives/mod.rs +++ b/sqlx-macros/src/derives/mod.rs @@ -7,6 +7,7 @@ pub(crate) use decode::expand_derive_decode; pub(crate) use encode::expand_derive_encode; pub(crate) use r#type::expand_derive_type; +use self::attributes::RenameAll; use std::iter::FromIterator; use syn::DeriveInput; @@ -23,3 +24,11 @@ pub(crate) fn expand_derive_type_encode_decode( Ok(combined) } + +pub(crate) fn rename_all(s: &str, pattern: RenameAll) -> String { + match pattern { + RenameAll::LowerCase => { + s.to_lowercase() + } + } +} diff --git a/sqlx-macros/src/derives/type.rs b/sqlx-macros/src/derives/type.rs index 0c86cb8a8..9fbec169b 100644 --- a/sqlx-macros/src/derives/type.rs +++ b/sqlx-macros/src/derives/type.rs @@ -1,6 +1,6 @@ use super::attributes::{ check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, - check_weak_enum_attributes, parse_attributes, + check_weak_enum_attributes, parse_container_attributes, }; use quote::quote; use syn::punctuated::Punctuated; @@ -11,7 +11,7 @@ use syn::{ }; pub fn expand_derive_type(input: &DeriveInput) -> syn::Result { - let attrs = parse_attributes(&input.attrs)?; + let attrs = parse_container_attributes(&input.attrs)?; match &input.data { Data::Struct(DataStruct { fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), @@ -65,64 +65,36 @@ fn expand_derive_has_sql_type_transparent( .make_where_clause() .predicates .push(parse_quote!(#ty: sqlx::types::Type)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); - let mut tts = proc_macro2::TokenStream::new(); - - // if cfg!(feature = "mysql") { - tts.extend(quote!( + Ok(quote!( impl #impl_generics sqlx::types::Type< DB > for #ident #ty_generics #where_clause { fn type_info() -> DB::TypeInfo { <#ty as sqlx::Type>::type_info() } } - )); - - // } - - // if cfg!(feature = "postgres") { - // tts.extend(quote!( - // impl #impl_generics sqlx::types::HasSqlType< sqlx::Postgres > #ident #ty_generics #where_clause { - // fn type_info() -> Self::TypeInfo { - // >::type_info() - // } - // } - // )); - // } - - Ok(tts) + )) } fn expand_derive_has_sql_type_weak_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { - let repr = check_weak_enum_attributes(input, variants)?; - + let attr = check_weak_enum_attributes(input, variants)?; + let repr = attr.repr.unwrap(); let ident = &input.ident; - let mut tts = proc_macro2::TokenStream::new(); - if cfg!(feature = "mysql") { - tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::MySql where Self: sqlx::types::HasSqlType< #repr > { - fn type_info() -> Self::TypeInfo { - >::type_info() - } + Ok(quote!( + impl sqlx::Type for #ident + where + #repr: sqlx::Type, + { + fn type_info() -> DB::TypeInfo { + <#repr as sqlx::Type>::type_info() } - )); - } - - if cfg!(feature = "postgres") { - tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres where Self: sqlx::types::HasSqlType< #repr > { - fn type_info() -> Self::TypeInfo { - >::type_info() - } - } - )); - } - - Ok(tts) + } + )) } fn expand_derive_has_sql_type_strong_enum( @@ -136,9 +108,11 @@ fn expand_derive_has_sql_type_strong_enum( if cfg!(feature = "mysql") { tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::MySql { - fn type_info() -> Self::TypeInfo { - sqlx::mysql::MySqlTypeInfo::r#enum() + impl sqlx::Type< sqlx::MySql > for #ident { + fn type_info() -> sqlx::mysql::MySqlTypeInfo { + // This is really fine, MySQL is loosely typed and + // we don't nede to be specific here + >::type_info() } } )); @@ -147,8 +121,8 @@ fn expand_derive_has_sql_type_strong_enum( if cfg!(feature = "postgres") { let oid = attributes.postgres_oid.unwrap(); tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres { - fn type_info() -> Self::TypeInfo { + impl sqlx::Type< sqlx::Postgres > for #ident { + fn type_info() -> sqlx::postgres::PgTypeInfo { sqlx::postgres::PgTypeInfo::with_oid(#oid) } } @@ -170,8 +144,8 @@ fn expand_derive_has_sql_type_struct( if cfg!(feature = "postgres") { let oid = attributes.postgres_oid.unwrap(); tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres { - fn type_info() -> Self::TypeInfo { + impl sqlx::types::Type< sqlx::Postgres > for #ident { + fn type_info() -> sqlx::postgres::PgTypeInfo { sqlx::postgres::PgTypeInfo::with_oid(#oid) } } diff --git a/tests/derives.rs b/tests/derives.rs deleted file mode 100644 index 5b5bd2894..000000000 --- a/tests/derives.rs +++ /dev/null @@ -1,318 +0,0 @@ -use sqlx::decode::Decode; -use sqlx::encode::Encode; -use sqlx::types::TypeInfo; -use sqlx::Type; -use std::fmt::Debug; - -#[derive(PartialEq, Debug, Type)] -#[sqlx(transparent)] -struct Transparent(i32); - -// #[derive(PartialEq, Debug, Clone, Copy, Encode, Decode, HasSqlType)] -// #[repr(i32)] -// #[allow(dead_code)] -// enum Weak { -// One, -// Two, -// Three, -// } -// -// #[derive(PartialEq, Debug, Encode, Decode, HasSqlType)] -// #[sqlx(postgres(oid = 10101010))] -// #[allow(dead_code)] -// enum Strong { -// One, -// Two, -// #[sqlx(rename = "four")] -// Three, -// } -// -// #[derive(PartialEq, Debug, Encode, Decode, HasSqlType)] -// #[sqlx(postgres(oid = 20202020))] -// #[allow(dead_code)] -// struct Struct { -// field1: String, -// field2: i64, -// field3: bool, -// } - -#[test] -#[cfg(feature = "mysql")] -fn encode_transparent_mysql() { - encode_transparent::(); -} - -#[test] -#[cfg(feature = "postgres")] -fn encode_transparent_postgres() { - encode_transparent::(); -} - -#[allow(dead_code)] -fn encode_transparent>>() -where - Transparent: Encode, - i32: Encode, -{ - let example = Transparent(0x1122_3344); - - let mut encoded = Vec::new(); - let mut encoded_orig = Vec::new(); - - Encode::::encode(&example, &mut encoded); - Encode::::encode(&example.0, &mut encoded_orig); - - assert_eq!(encoded, encoded_orig); -} -// -// #[test] -// #[cfg(feature = "mysql")] -// fn encode_weak_enum_mysql() { -// encode_weak_enum::(); -// } -// -// #[test] -// #[cfg(feature = "postgres")] -// fn encode_weak_enum_postgres() { -// encode_weak_enum::(); -// } -// -// #[allow(dead_code)] -// fn encode_weak_enum>>() -// where -// Weak: Encode, -// i32: Encode, -// { -// for example in [Weak::One, Weak::Two, Weak::Three].iter() { -// let mut encoded = Vec::new(); -// let mut encoded_orig = Vec::new(); -// -// Encode::::encode(example, &mut encoded); -// Encode::::encode(&(*example as i32), &mut encoded_orig); -// -// assert_eq!(encoded, encoded_orig); -// } -// } -// -// #[test] -// #[cfg(feature = "mysql")] -// fn encode_strong_enum_mysql() { -// encode_strong_enum::(); -// } -// -// #[test] -// #[cfg(feature = "postgres")] -// fn encode_strong_enum_postgres() { -// encode_strong_enum::(); -// } -// -// #[allow(dead_code)] -// fn encode_strong_enum>>() -// where -// Strong: Encode, -// str: Encode, -// { -// for (example, name) in [ -// (Strong::One, "One"), -// (Strong::Two, "Two"), -// (Strong::Three, "four"), -// ] -// .iter() -// { -// let mut encoded = Vec::new(); -// let mut encoded_orig = Vec::new(); -// -// Encode::::encode(example, &mut encoded); -// Encode::::encode(*name, &mut encoded_orig); -// -// assert_eq!(encoded, encoded_orig); -// } -// } -// -// #[test] -// #[cfg(feature = "postgres")] -// fn encode_struct_postgres() { -// let field1 = "Foo".to_string(); -// let field2 = 3; -// let field3 = false; -// -// let example = Struct { -// field1: field1.clone(), -// field2, -// field3, -// }; -// -// let mut encoded = Vec::new(); -// Encode::::encode(&example, &mut encoded); -// -// let string_oid = >::type_info().oid(); -// let i64_oid = >::type_info().oid(); -// let bool_oid = >::type_info().oid(); -// -// // 3 columns -// assert_eq!(&[0, 0, 0, 3], &encoded[..4]); -// let encoded = &encoded[4..]; -// -// // check field1 (string) -// assert_eq!(&string_oid.to_be_bytes(), &encoded[0..4]); -// assert_eq!(&(field1.len() as u32).to_be_bytes(), &encoded[4..8]); -// assert_eq!(field1.as_bytes(), &encoded[8..8 + field1.len()]); -// let encoded = &encoded[8 + field1.len()..]; -// -// // check field2 (i64) -// assert_eq!(&i64_oid.to_be_bytes(), &encoded[0..4]); -// assert_eq!(&8u32.to_be_bytes(), &encoded[4..8]); -// assert_eq!(field2.to_be_bytes(), &encoded[8..16]); -// let encoded = &encoded[16..]; -// -// // check field3 (bool) -// assert_eq!(&bool_oid.to_be_bytes(), &encoded[0..4]); -// assert_eq!(&1u32.to_be_bytes(), &encoded[4..8]); -// assert_eq!(field3, encoded[8] != 0); -// let encoded = &encoded[9..]; -// -// assert!(encoded.is_empty()); -// -// let string_size = >::size_hint(&field1); -// let i64_size = >::size_hint(&field2); -// let bool_size = >::size_hint(&field3); -// -// assert_eq!( -// 4 + 3 * (4 + 4) + string_size + i64_size + bool_size, -// example.size_hint() -// ); -// } - -#[test] -#[cfg(feature = "mysql")] -fn decode_transparent_mysql() { - decode_with_db::(Transparent(0x1122_3344)); -} - -#[test] -#[cfg(feature = "postgres")] -fn decode_transparent_postgres() { - decode_with_db::(Transparent(0x1122_3344)); -} -// -// #[test] -// #[cfg(feature = "mysql")] -// fn decode_weak_enum_mysql() { -// decode_with_db::(Weak::One); -// decode_with_db::(Weak::Two); -// decode_with_db::(Weak::Three); -// } -// -// #[test] -// #[cfg(feature = "postgres")] -// fn decode_weak_enum_postgres() { -// decode_with_db::(Weak::One); -// decode_with_db::(Weak::Two); -// decode_with_db::(Weak::Three); -// } -// -// #[test] -// #[cfg(feature = "mysql")] -// fn decode_strong_enum_mysql() { -// decode_with_db::(Strong::One); -// decode_with_db::(Strong::Two); -// decode_with_db::(Strong::Three); -// } -// -// #[test] -// #[cfg(feature = "postgres")] -// fn decode_strong_enum_postgres() { -// decode_with_db::(Strong::One); -// decode_with_db::(Strong::Two); -// decode_with_db::(Strong::Three); -// } -// -// #[test] -// #[cfg(feature = "postgres")] -// fn decode_struct_postgres() { -// decode_with_db::(Struct { -// field1: "Foo".to_string(), -// field2: 3, -// field3: true, -// }); -// } -// -#[allow(dead_code)] -fn decode_with_db< - DB: sqlx::Database>, - V: for<'de> Decode<'de, DB> + Encode + PartialEq + Debug, ->( - example: V, -) { - let mut encoded = Vec::new(); - Encode::::encode(&example, &mut encoded); - - // let decoded = V::decode(&encoded).unwrap(); - // assert_eq!(example, decoded); -} - -#[test] -#[cfg(feature = "mysql")] -fn type_transparent_mysql() { - type_transparent::(); -} - -#[test] -#[cfg(feature = "postgres")] -fn type_transparent_postgres() { - type_transparent::(); -} - -#[allow(dead_code)] -fn type_transparent>>() -where - Transparent: Type, - i32: Type, -{ - let info: DB::TypeInfo = >::type_info(); - let info_orig: DB::TypeInfo = >::type_info(); - assert!(info.compatible(&info_orig)); -} -// -// #[test] -// #[cfg(feature = "mysql")] -// fn type_weak_enum_mysql() { -// type_weak_enum::(); -// } -// -// #[test] -// #[cfg(feature = "postgres")] -// fn type_weak_enum_postgres() { -// type_weak_enum::(); -// } -// -// #[allow(dead_code)] -// fn type_weak_enum>>() -// where -// DB: HasSqlType + HasSqlType, -// { -// let info: DB::TypeInfo = >::type_info(); -// let info_orig: DB::TypeInfo = >::type_info(); -// assert!(info.compatible(&info_orig)); -// } -// -// #[test] -// #[cfg(feature = "mysql")] -// fn type_strong_enum_mysql() { -// let info: sqlx::mysql::MySqlTypeInfo = >::type_info(); -// assert!(info.compatible(&sqlx::mysql::MySqlTypeInfo::r#enum())) -// } -// -// #[test] -// #[cfg(feature = "postgres")] -// fn type_strong_enum_postgres() { -// let info: sqlx::postgres::PgTypeInfo = >::type_info(); -// assert!(info.compatible(&sqlx::postgres::PgTypeInfo::with_oid(10101010))) -// } -// -// #[test] -// #[cfg(feature = "postgres")] -// fn type_struct_postgres() { -// let info: sqlx::postgres::PgTypeInfo = >::type_info(); -// assert!(info.compatible(&sqlx::postgres::PgTypeInfo::with_oid(20202020))) -// } diff --git a/tests/mysql-derives.rs b/tests/mysql-derives.rs new file mode 100644 index 000000000..c92d06f94 --- /dev/null +++ b/tests/mysql-derives.rs @@ -0,0 +1,47 @@ +use sqlx_test::test_type; +use std::fmt::Debug; +use sqlx::MySql; + +// Transparent types are rust-side wrappers over DB types +#[derive(PartialEq, Debug, sqlx::Type)] +#[sqlx(transparent)] +struct Transparent(i32); + +// "Weak" enums map to an integer type indicated by #[repr] +#[derive(PartialEq, Copy, Clone, Debug, sqlx::Type)] +#[repr(i32)] +enum Weak { + One = 0, + Two = 2, + Three = 4, +} + +// "Strong" enums can map to TEXT or a custom enum +#[derive(PartialEq, Debug, sqlx::Type)] +#[sqlx(rename_all = "lowercase")] +enum Color { + Red, + Green, + Blue, +} + +test_type!(transparent( + MySql, + Transparent, + "0" == Transparent(0), + "23523" == Transparent(23523) +)); + +test_type!(weak_enum( + MySql, + Weak, + "0" == Weak::One, + "2" == Weak::Two, + "4" == Weak::Three +)); + +test_type!(strong_color_enum( + MySql, + Color, + "'green'" == Color::Green +)); diff --git a/tests/postgres-derives.rs b/tests/postgres-derives.rs new file mode 100644 index 000000000..0a3504d3a --- /dev/null +++ b/tests/postgres-derives.rs @@ -0,0 +1,81 @@ +use sqlx_test::test_type; +use std::fmt::Debug; +use sqlx::Postgres; + +// Transparent types are rust-side wrappers over DB types +#[derive(PartialEq, Debug, sqlx::Type)] +#[sqlx(transparent)] +struct Transparent(i32); + +// "Weak" enums map to an integer type indicated by #[repr] +#[derive(PartialEq, Copy, Clone, Debug, sqlx::Type)] +#[repr(i32)] +enum Weak { + One = 0, + Two = 2, + Three = 4, +} + +// "Strong" enums can map to TEXT (25) or a custom enum type +#[derive(PartialEq, Debug, sqlx::Type)] +#[sqlx(postgres(oid = 25))] +#[sqlx(rename_all = "lowercase")] +enum Strong { + One, + Two, + + #[sqlx(rename = "four")] + Three, +} + +// Records must map to a custom type +// Note that all types are types in Postgres +#[derive(PartialEq, Debug, sqlx::Type)] +#[sqlx(postgres(oid = 12184))] +struct PgConfig { + name: String, + setting: Option, +} + +test_type!(transparent( + Postgres, + Transparent, + "0" == Transparent(0), + "23523" == Transparent(23523) +)); + +test_type!(weak_enum( + Postgres, + Weak, + "0::int4" == Weak::One, + "2::int4" == Weak::Two, + "4::int4" == Weak::Three +)); + +test_type!(strong_enum( + Postgres, + Strong, + "'one'::text" == Strong::One, + "'two'::text" == Strong::Two, + "'four'::text" == Strong::Three +)); + +test_type!(record_pg_config( + Postgres, + PgConfig, + // (CC,gcc) + "(SELECT ROW('CC', 'gcc')::pg_config)" == PgConfig { + name: "CC".to_owned(), + setting: Some("gcc".to_owned()), + }, + // (CC,) + "(SELECT '(\"CC\",)'::pg_config)" == PgConfig { + name: "CC".to_owned(), + setting: None, + }, + // (CC,"") + "(SELECT '(\"CC\",\"\")'::pg_config)" == PgConfig { + name: "CC".to_owned(), + setting: Some("".to_owned()), + } +)); diff --git a/tests/postgres.rs b/tests/postgres.rs index 2eb1c9227..5d2b7d597 100644 --- a/tests/postgres.rs +++ b/tests/postgres.rs @@ -1,12 +1,13 @@ use futures::TryStreamExt; +use sqlx_test::new; use sqlx::postgres::{PgPool, PgQueryAs, PgRow}; -use sqlx::{postgres::PgConnection, Connect, Connection, Executor, Row}; +use sqlx::{Postgres, Connection, Executor, Row}; use std::time::Duration; #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn it_connects() -> anyhow::Result<()> { - let mut conn = connect().await?; + let mut conn = new::().await?; let value = sqlx::query("select 1 + 1") .try_map(|row: PgRow| row.try_get::(0)) @@ -21,7 +22,7 @@ async fn it_connects() -> anyhow::Result<()> { #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn it_executes() -> anyhow::Result<()> { - let mut conn = connect().await?; + let mut conn = new::().await?; let _ = conn .execute( @@ -55,7 +56,7 @@ CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY); #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn it_can_return_interleaved_nulls_issue_104() -> anyhow::Result<()> { - let mut conn = connect().await?; + let mut conn = new::().await?; let tuple = sqlx::query("SELECT NULL::INT, 10::INT, NULL, 20::INT, NULL, 40::INT, NULL, 80::INT") @@ -89,7 +90,7 @@ async fn it_can_return_interleaved_nulls_issue_104() -> anyhow::Result<()> { #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn it_can_work_with_transactions() -> anyhow::Result<()> { - let mut conn = connect().await?; + let mut conn = new::().await?; conn.execute("CREATE TABLE IF NOT EXISTS _sqlx_users_1922 (id INTEGER PRIMARY KEY)") .await?; @@ -141,7 +142,7 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> { .await?; } - conn = connect().await?; + conn = new::().await?; let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_1922") .fetch_one(&mut conn) @@ -210,7 +211,7 @@ async fn pool_smoke_test() -> anyhow::Result<()> { #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn test_describe() -> anyhow::Result<()> { - let mut conn = connect().await?; + let mut conn = new::().await?; let _ = conn .execute( @@ -239,10 +240,3 @@ async fn test_describe() -> anyhow::Result<()> { Ok(()) } - -async fn connect() -> anyhow::Result { - let _ = dotenv::dotenv(); - let _ = env_logger::try_init(); - - Ok(PgConnection::connect(dotenv::var("DATABASE_URL")?).await?) -}