use super::attributes::{ check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, check_weak_enum_attributes, parse_attributes, }; use proc_macro2::Ident; use quote::quote; use syn::punctuated::Punctuated; use syn::token::Comma; use syn::{ parse_quote, Arm, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed, FieldsUnnamed, Stmt, Variant, }; pub fn expand_derive_decode(input: &DeriveInput) -> syn::Result { let attrs = parse_attributes(&input.attrs)?; match &input.data { Data::Struct(DataStruct { fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), .. }) if unnamed.len() == 1 => { expand_derive_decode_transparent(input, unnamed.first().unwrap()) } Data::Enum(DataEnum { variants, .. }) => match attrs.repr { Some(_) => expand_derive_decode_weak_enum(input, variants), None => expand_derive_decode_strong_enum(input, variants), }, Data::Struct(DataStruct { fields: Fields::Named(FieldsNamed { named, .. }), .. }) => expand_derive_decode_struct(input, named), _ => Err(syn::Error::new_spanned( input, "expected a tuple struct with a single field", )), } } fn expand_derive_decode_transparent( input: &DeriveInput, field: &Field, ) -> syn::Result { check_transparent_attributes(input, field)?; let ident = &input.ident; let ty = &field.ty; // extract type generics let generics = &input.generics; let (_, ty_generics, _) = generics.split_for_impl(); // add db type for impl generics & where clause let mut generics = generics.clone(); generics.params.insert(0, parse_quote!(DB: sqlx::Database)); generics .make_where_clause() .predicates .push(parse_quote!(#ty: sqlx::decode::Decode)); let (impl_generics, _, where_clause) = generics.split_for_impl(); Ok(quote!( impl #impl_generics sqlx::decode::Decode for #ident #ty_generics #where_clause { fn decode(raw: &[u8]) -> std::result::Result { <#ty as sqlx::decode::Decode>::decode(raw).map(Self) } fn decode_null() -> std::result::Result { <#ty as sqlx::decode::Decode>::decode_null().map(Self) } fn decode_nullable(raw: std::option::Option<&[u8]>) -> std::result::Result { <#ty as sqlx::decode::Decode>::decode_nullable(raw).map(Self) } } )) } fn expand_derive_decode_weak_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { let repr = check_weak_enum_attributes(input, &variants)?; let ident = &input.ident; let arms = variants .iter() .map(|v| { let id = &v.ident; parse_quote!(_ if (#ident :: #id as #repr) == val => Ok(#ident :: #id),) }) .collect::>(); Ok(quote!( impl sqlx::decode::Decode for #ident where #repr: sqlx::decode::Decode { fn decode(raw: &[u8]) -> std::result::Result { let val = <#repr as sqlx::decode::Decode>::decode(raw)?; match val { #(#arms)* _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) } } } )) } fn expand_derive_decode_strong_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { check_strong_enum_attributes(input, &variants)?; let ident = &input.ident; let mut value_arms = Vec::new(); for v in variants { let id = &v.ident; let attributes = parse_attributes(&v.attrs)?; if let Some(rename) = attributes.rename { value_arms.push(quote!(#rename => Ok(#ident :: #id),)); } else { let name = id.to_string(); value_arms.push(quote!(#name => Ok(#ident :: #id),)); } } // TODO: prevent heap allocation Ok(quote!( impl sqlx::decode::Decode for #ident where std::string::String: sqlx::decode::Decode { fn decode(buf: &[u8]) -> std::result::Result { let val = >::decode(buf)?; match val.as_str() { #(#value_arms)* _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) } } } )) } fn expand_derive_decode_struct( input: &DeriveInput, fields: &Punctuated, ) -> syn::Result { check_struct_attributes(input, fields)?; let mut tts = proc_macro2::TokenStream::new(); if cfg!(feature = "postgres") { let ident = &input.ident; let column_count = fields.len(); // extract type generics let generics = &input.generics; let (_, ty_generics, _) = generics.split_for_impl(); // add db type for impl generics & where clause let mut generics = generics.clone(); let predicates = &mut generics.make_where_clause().predicates; for field in fields { let ty = &field.ty; predicates.push(parse_quote!(#ty: sqlx::decode::Decode)); predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); } let (impl_generics, _, where_clause) = generics.split_for_impl(); let mut reads: Vec = Vec::new(); let mut names: Vec = Vec::new(); for field in fields { let id = &field.ident; names.push(id.clone().unwrap()); let ty = &field.ty; reads.push(parse_quote!( let #id = sqlx::postgres::decode_struct_field::<#ty>(&mut buf)?; )); } tts.extend(quote!( impl #impl_generics sqlx::decode::Decode for #ident#ty_generics #where_clause { fn decode(buf: &[u8]) -> std::result::Result { if buf.len() < 4 { return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); } let column_count = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[..4]).unwrap()) as usize; if column_count != #column_count { return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid column count"))); } let mut buf = &buf[4..]; #(#reads)* if !buf.is_empty() { return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new(format!("Too much data sent ({} bytes left)", buf.len())))); } Ok(#ident { #(#names),* }) } } )) } Ok(tts) }