mod enums;
mod parse;
mod shared;

use parse::{Invocation, StructuredInput};
use proc_macro as pm;
use proc_macro2::{self as pm2, Span};
use quote::{ToTokens, quote};
pub(crate) use shared::{ALL_OPERATIONS, FloatTy, MathOpInfo, Ty};
use syn::spanned::Spanned;
use syn::visit_mut::VisitMut;
use syn::{Ident, ItemEnum};

/// Type identifiers that a caller may request through the `emit_types` field of
/// `for_each_function!` (in addition to the special value `all`).
const KNOWN_TYPES: &[&str] = &["FTy", "CFn", "CArgs", "CRet", "RustFn", "RustArgs", "RustRet"];

/// Populate an enum with one variant per function. Variant names are in upper camel case.
///
/// Applied to an empty enum. Expects one attribute `#[function_enum(BaseName)]` that provides
/// the name of the `BaseName` enum.
#[proc_macro_attribute]
pub fn function_enum(attributes: pm::TokenStream, tokens: pm::TokenStream) -> pm::TokenStream {
    let item = syn::parse_macro_input!(tokens as ItemEnum);

    // Any failure is turned into a `compile_error!` invocation at the error's span.
    enums::function_enum(item, attributes.into())
        .unwrap_or_else(syn::Error::into_compile_error)
        .into()
}

/// Create an enum representing every possible base name. Variant names are in upper camel case.
///
/// Applied to an empty enum.
#[proc_macro_attribute]
pub fn base_name_enum(attributes: pm::TokenStream, tokens: pm::TokenStream) -> pm::TokenStream {
    let item = syn::parse_macro_input!(tokens as ItemEnum);

    // Any failure is turned into a `compile_error!` invocation at the error's span.
    enums::base_name_enum(item, attributes.into())
        .unwrap_or_else(syn::Error::into_compile_error)
        .into()
}

/// Do something for each function present in this crate.
///
/// Takes a callback macro and invokes it multiple times, once for each function that
/// this crate exports. This makes it easy to create generic tests, benchmarks, or other checks
/// and apply them to each symbol.
///
/// Additionally, the `extra` and `fn_extra` patterns can make use of magic identifiers:
///
/// - `MACRO_FN_NAME`: gets replaced with the name of the function on that invocation.
/// - `MACRO_FN_NAME_NORMALIZED`: similar to the above, but removes suffixes so e.g. `sinf`
///   becomes `sin`, `cosf128` becomes `cos`, etc.
/// /// Invoke as: /// /// ``` /// // Macro that is invoked once per function /// macro_rules! callback_macro { /// ( /// // Name of that function /// fn_name: $fn_name:ident, /// // The basic float type for this function (e.g. `f32`, `f64`) /// FTy: $FTy:ty, /// // Function signature of the C version (e.g. `fn(f32, &mut f32) -> f32`) /// CFn: $CFn:ty, /// // A tuple representing the C version's arguments (e.g. `(f32, &mut f32)`) /// CArgs: $CArgs:ty, /// // The C version's return type (e.g. `f32`) /// CRet: $CRet:ty, /// // Function signature of the Rust version (e.g. `fn(f32) -> (f32, f32)`) /// RustFn: $RustFn:ty, /// // A tuple representing the Rust version's arguments (e.g. `(f32,)`) /// RustArgs: $RustArgs:ty, /// // The Rust version's return type (e.g. `(f32, f32)`) /// RustRet: $RustRet:ty, /// // Attributes for the current function, if any /// attrs: [$($attr:meta),*], /// // Extra tokens passed directly (if any) /// extra: [$extra:ident], /// // Extra function-tokens passed directly (if any) /// fn_extra: $fn_extra:expr, /// ) => { }; /// } /// /// // All fields except for `callback` are optional. /// libm_macros::for_each_function! { /// // The macro to invoke as a callback /// callback: callback_macro, /// // Which types to include either as a list (`[CFn, RustFn, RustArgs]`) or "all" /// emit_types: all, /// // Functions to skip, i.e. `callback` shouldn't be called at all for these. /// skip: [sin, cos], /// // Attributes passed as `attrs` for specific functions. For example, here the invocation /// // with `sinf` and that with `cosf` will both get `meta1` and `meta2`, but no others will. /// // /// // Note that `f16_enabled` and `f128_enabled` will always get emitted regardless of whether /// // or not this is specified. /// attributes: [ /// #[meta1] /// #[meta2] /// [sinf, cosf], /// ], /// // Any tokens that should be passed directly to all invocations of the callback. 
This can /// // be used to pass local variables or other things the macro needs access to. /// extra: [foo], /// // Similar to `extra`, but allow providing a pattern for only specific functions. Uses /// // a simplified match-like syntax. /// fn_extra: match MACRO_FN_NAME { /// hypot | hypotf => |x| x.hypot(), /// _ => |x| x, /// }, /// } /// ``` #[proc_macro] pub fn for_each_function(tokens: pm::TokenStream) -> pm::TokenStream { let input = syn::parse_macro_input!(tokens as Invocation); let res = StructuredInput::from_fields(input) .and_then(|mut s_in| validate(&mut s_in).map(|fn_list| (s_in, fn_list))) .and_then(|(s_in, fn_list)| expand(s_in, &fn_list)); match res { Ok(ts) => ts.into(), Err(e) => e.into_compile_error().into(), } } /// Check for any input that is structurally correct but has other problems. /// /// Returns the list of function names that we should expand for. fn validate(input: &mut StructuredInput) -> syn::Result> { // Collect lists of all functions that are provied as macro inputs in various fields (only, // skip, attributes). 
let attr_mentions = input .attributes .iter() .flat_map(|map_list| map_list.iter()) .flat_map(|attr_map| attr_map.names.iter()); let only_mentions = input.only.iter().flat_map(|only_list| only_list.iter()); let fn_extra_mentions = input.fn_extra.iter().flat_map(|v| v.keys()).filter(|name| *name != "_"); let all_mentioned_fns = input.skip.iter().chain(only_mentions).chain(attr_mentions).chain(fn_extra_mentions); // Make sure that every function mentioned is a real function for mentioned in all_mentioned_fns { if !ALL_OPERATIONS.iter().any(|func| mentioned == func.name) { let e = syn::Error::new( mentioned.span(), format!("unrecognized function name `{mentioned}`"), ); return Err(e); } } if !input.skip.is_empty() && input.only.is_some() { let e = syn::Error::new( input.only_span.unwrap(), "only one of `skip` or `only` may be specified", ); return Err(e); } // Construct a list of what we intend to expand let mut fn_list = Vec::new(); for func in ALL_OPERATIONS.iter() { let fn_name = func.name; // If we have an `only` list and it does _not_ contain this function name, skip it if input.only.as_ref().is_some_and(|only| !only.iter().any(|o| o == fn_name)) { continue; } // If there is a `skip` list that contains this function name, skip it if input.skip.iter().any(|s| s == fn_name) { continue; } // Run everything else fn_list.push(func); } // Types that the user would like us to provide in the macro let mut add_all_types = false; for ty in &input.emit_types { let ty_name = ty.to_string(); if ty_name == "all" { add_all_types = true; continue; } // Check that all requested types are valid if !KNOWN_TYPES.contains(&ty_name.as_str()) { let e = syn::Error::new( ty_name.span(), format!("unrecognized type identifier `{ty_name}`"), ); return Err(e); } } if add_all_types { // Ensure that if `all` was specified that nothing else was if input.emit_types.len() > 1 { let e = syn::Error::new( input.emit_types_span.unwrap(), "if `all` is specified, no other type identifiers may be 
given", ); return Err(e); } // ...and then add all types input.emit_types.clear(); for ty in KNOWN_TYPES { let ident = Ident::new(ty, Span::call_site()); input.emit_types.push(ident); } } if let Some(map) = &input.fn_extra { if !map.keys().any(|key| key == "_") { // No default provided; make sure every expected function is covered let mut fns_not_covered = Vec::new(); for func in &fn_list { if !map.keys().any(|key| key == func.name) { // `name` was not mentioned in the `match` statement fns_not_covered.push(func); } } if !fns_not_covered.is_empty() { let e = syn::Error::new( input.fn_extra_span.unwrap(), format!( "`fn_extra`: no default `_` pattern specified and the following \ patterns are not covered: {fns_not_covered:#?}" ), ); return Err(e); } } }; Ok(fn_list) } /// Expand our structured macro input into invocations of the callback macro. fn expand(input: StructuredInput, fn_list: &[&MathOpInfo]) -> syn::Result { let mut out = pm2::TokenStream::new(); let default_ident = Ident::new("_", Span::call_site()); let callback = input.callback; for func in fn_list { let fn_name = Ident::new(func.name, Span::call_site()); // Prepare attributes in an `attrs: ...` field let mut meta_fields = Vec::new(); if let Some(attrs) = &input.attributes { let meta_iter = attrs .iter() .filter(|map| map.names.contains(&fn_name)) .flat_map(|map| &map.meta) .map(|v| v.into_token_stream()); meta_fields.extend(meta_iter); } // Always emit f16 and f128 meta so this doesn't need to be repeated everywhere if func.rust_sig.args.contains(&Ty::F16) || func.rust_sig.returns.contains(&Ty::F16) { let ts = quote! { cfg(f16_enabled) }; meta_fields.push(ts); } if func.rust_sig.args.contains(&Ty::F128) || func.rust_sig.returns.contains(&Ty::F128) { let ts = quote! { cfg(f128_enabled) }; meta_fields.push(ts); } let meta_field = quote! 
{ attrs: [ #( #meta_fields ),* ], }; // Prepare extra in an `extra: ...` field, running the replacer let extra_field = match input.extra.clone() { Some(mut extra) => { let mut v = MacroReplace::new(func.name); v.visit_expr_mut(&mut extra); v.finish()?; quote! { extra: #extra, } } None => pm2::TokenStream::new(), }; // Prepare function-specific extra in a `fn_extra: ...` field, running the replacer let fn_extra_field = match input.fn_extra { Some(ref map) => { let mut fn_extra = map.get(&fn_name).or_else(|| map.get(&default_ident)).unwrap().clone(); let mut v = MacroReplace::new(func.name); v.visit_expr_mut(&mut fn_extra); v.finish()?; quote! { fn_extra: #fn_extra, } } None => pm2::TokenStream::new(), }; let base_fty = func.float_ty; let c_args = &func.c_sig.args; let c_ret = &func.c_sig.returns; let rust_args = &func.rust_sig.args; let rust_ret = &func.rust_sig.returns; let mut ty_fields = Vec::new(); for ty in &input.emit_types { let field = match ty.to_string().as_str() { "FTy" => quote! { FTy: #base_fty, }, "CFn" => quote! { CFn: fn( #(#c_args),* ,) -> ( #(#c_ret),* ), }, "CArgs" => quote! { CArgs: ( #(#c_args),* ,), }, "CRet" => quote! { CRet: ( #(#c_ret),* ), }, "RustFn" => quote! { RustFn: fn( #(#rust_args),* ,) -> ( #(#rust_ret),* ), }, "RustArgs" => quote! { RustArgs: ( #(#rust_args),* ,), }, "RustRet" => quote! { RustRet: ( #(#rust_ret),* ), }, _ => unreachable!("checked in validation"), }; ty_fields.push(field); } let new = quote! { #callback! { fn_name: #fn_name, #( #ty_fields )* #meta_field #extra_field #fn_extra_field } }; out.extend(new); } Ok(out) } /// Visitor to replace "magic" identifiers that we allow: `MACRO_FN_NAME` and /// `MACRO_FN_NAME_NORMALIZED`. 
struct MacroReplace { fn_name: &'static str, /// Remove the trailing `f` or `f128` to make norm_name: String, error: Option, } impl MacroReplace { fn new(name: &'static str) -> Self { let norm_name = base_name(name); Self { fn_name: name, norm_name: norm_name.to_owned(), error: None } } fn finish(self) -> syn::Result<()> { match self.error { Some(e) => Err(e), None => Ok(()), } } fn visit_ident_inner(&mut self, i: &mut Ident) { let s = i.to_string(); if !s.starts_with("MACRO") || self.error.is_some() { return; } match s.as_str() { "MACRO_FN_NAME" => *i = Ident::new(self.fn_name, i.span()), "MACRO_FN_NAME_NORMALIZED" => *i = Ident::new(&self.norm_name, i.span()), _ => { self.error = Some(syn::Error::new(i.span(), format!("unrecognized meta expression `{s}`"))); } } } } impl VisitMut for MacroReplace { fn visit_ident_mut(&mut self, i: &mut Ident) { self.visit_ident_inner(i); syn::visit_mut::visit_ident_mut(self, i); } } /// Return the unsuffixed version of a function name; e.g. `abs` and `absf` both return `abs`, /// `lgamma_r` and `lgammaf_r` both return `lgamma_r`. fn base_name(name: &str) -> &str { let known_mappings = &[ ("erff", "erf"), ("erf", "erf"), ("lgammaf_r", "lgamma_r"), ("modff", "modf"), ("modf", "modf"), ]; match known_mappings.iter().find(|known| known.0 == name) { Some(found) => found.1, None => name .strip_suffix("f") .or_else(|| name.strip_suffix("f16")) .or_else(|| name.strip_suffix("f128")) .unwrap_or(name), } } impl ToTokens for Ty { fn to_tokens(&self, tokens: &mut pm2::TokenStream) { let ts = match self { Ty::F16 => quote! { f16 }, Ty::F32 => quote! { f32 }, Ty::F64 => quote! { f64 }, Ty::F128 => quote! { f128 }, Ty::I32 => quote! { i32 }, Ty::CInt => quote! { ::core::ffi::c_int }, Ty::MutF16 => quote! { &'a mut f16 }, Ty::MutF32 => quote! { &'a mut f32 }, Ty::MutF64 => quote! { &'a mut f64 }, Ty::MutF128 => quote! { &'a mut f128 }, Ty::MutI32 => quote! { &'a mut i32 }, Ty::MutCInt => quote! 
{ &'a mut core::ffi::c_int }, }; tokens.extend(ts); } } impl ToTokens for FloatTy { fn to_tokens(&self, tokens: &mut pm2::TokenStream) { let ts = match self { FloatTy::F16 => quote! { f16 }, FloatTy::F32 => quote! { f32 }, FloatTy::F64 => quote! { f64 }, FloatTy::F128 => quote! { f128 }, }; tokens.extend(ts); } }