//! This module contains the implementation of the `#[autodiff]` attribute. //! Currently our linter isn't smart enough to see that each import is used in one of the two //! configs (autodiff enabled or disabled), so we have to add cfg's to each import. //! FIXME(ZuseZ4): Remove this once we have a smarter linter. mod llvm_enzyme { use std::str::FromStr; use std::string::String; use rustc_ast::expand::autodiff_attrs::{ AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity, valid_ty_for_activity, }; use rustc_ast::ptr::P; use rustc_ast::token::{Lit, LitKind, Token, TokenKind}; use rustc_ast::tokenstream::*; use rustc_ast::visit::AssocCtxt::*; use rustc_ast::{ self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind, MetaItemInner, PatKind, QSelf, TyKind, Visibility, }; use rustc_expand::base::{Annotatable, ExtCtxt}; use rustc_span::{Ident, Span, Symbol, kw, sym}; use thin_vec::{ThinVec, thin_vec}; use tracing::{debug, trace}; use crate::errors; pub(crate) fn outer_normal_attr( kind: &P, id: rustc_ast::AttrId, span: Span, ) -> rustc_ast::Attribute { let style = rustc_ast::AttrStyle::Outer; let kind = rustc_ast::AttrKind::Normal(kind.clone()); rustc_ast::Attribute { kind, id, style, span } } // If we have a default `()` return type or explicitley `()` return type, // then we often can skip doing some work. fn has_ret(ty: &FnRetTy) -> bool { match ty { FnRetTy::Ty(ty) => !ty.kind.is_unit(), FnRetTy::Default(_) => false, } } fn first_ident(x: &MetaItemInner) -> rustc_span::Ident { if let Some(l) = x.lit() { match l.kind { ast::LitKind::Int(val, _) => { // get an Ident from a lit return rustc_span::Ident::from_str(val.get().to_string().as_str()); } _ => {} } } let segments = &x.meta_item().unwrap().path.segments; assert!(segments.len() == 1); segments[0].ident } fn name(x: &MetaItemInner) -> String { first_ident(x).name.to_string() } fn width(x: &MetaItemInner) -> Option { let lit = x.lit()?; match lit.kind { ast::LitKind::Int(x, _) => Some(x.get()), _ => return None, } } // Get information about the function the macro is applied to fn extract_item_info(iitem: &P) -> Option<(Visibility, FnSig, Ident, Generics)> { match &iitem.kind { ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => { Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone())) } _ => None, } } pub(crate) fn from_ast( ecx: &mut ExtCtxt<'_>, meta_item: &ThinVec, has_ret: bool, mode: DiffMode, ) -> AutoDiffAttrs { let dcx = ecx.sess.dcx(); // Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode. // If he doesn't specify an integer (=width), we default to scalar mode, thus width=1. let mut first_activity = 1; let width = if let [_, x, ..] = &meta_item[..] && let Some(x) = width(x) { first_activity = 2; match x.try_into() { Ok(x) => x, Err(_) => { dcx.emit_err(errors::AutoDiffInvalidWidth { span: meta_item[1].span(), width: x, }); return AutoDiffAttrs::error(); } } } else { 1 }; let mut activities: Vec = vec![]; let mut errors = false; for x in &meta_item[first_activity..] { let activity_str = name(&x); let res = DiffActivity::from_str(&activity_str); match res { Ok(x) => activities.push(x), Err(_) => { dcx.emit_err(errors::AutoDiffUnknownActivity { span: x.span(), act: activity_str, }); errors = true; } }; } if errors { return AutoDiffAttrs::error(); } // If a return type exist, we need to split the last activity, // otherwise we return None as placeholder. let (ret_activity, input_activity) = if has_ret { let Some((last, rest)) = activities.split_last() else { unreachable!( "should not be reachable because we counted the number of activities previously" ); }; (last, rest) } else { (&DiffActivity::None, activities.as_slice()) }; AutoDiffAttrs { mode, width, ret_activity: *ret_activity, input_activity: input_activity.to_vec(), } } fn meta_item_inner_to_ts(t: &MetaItemInner, ts: &mut Vec) { let comma: Token = Token::new(TokenKind::Comma, Span::default()); let val = first_ident(t); let t = Token::from_ast_ident(val); ts.push(TokenTree::Token(t, Spacing::Joint)); ts.push(TokenTree::Token(comma.clone(), Spacing::Alone)); } pub(crate) fn expand_forward( ecx: &mut ExtCtxt<'_>, expand_span: Span, meta_item: &ast::MetaItem, item: Annotatable, ) -> Vec { expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward) } pub(crate) fn expand_reverse( ecx: &mut ExtCtxt<'_>, expand_span: Span, meta_item: &ast::MetaItem, item: Annotatable, ) -> Vec { expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse) } /// We expand the autodiff macro to generate a new placeholder function which passes /// type-checking and can be called by users. The function body of the placeholder function will /// later be replaced on LLVM-IR level, so the design of the body is less important and for now /// should just prevent early inlining and optimizations which alter the function signature. /// The exact signature of the generated function depends on the configuration provided by the /// user, but here is an example: /// /// ``` /// #[autodiff(cos_box, Reverse, Duplicated, Active)] /// fn sin(x: &Box) -> f32 { /// f32::sin(**x) /// } /// ``` /// which becomes expanded to: /// ``` /// #[rustc_autodiff] /// #[inline(never)] /// fn sin(x: &Box) -> f32 { /// f32::sin(**x) /// } /// #[rustc_autodiff(Reverse, Duplicated, Active)] /// #[inline(never)] /// fn cos_box(x: &Box, dx: &mut Box, dret: f32) -> f32 { /// unsafe { /// asm!("NOP"); /// }; /// ::core::hint::black_box(sin(x)); /// ::core::hint::black_box((dx, dret)); /// ::core::hint::black_box(sin(x)) /// } /// ``` /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked /// in CI. pub(crate) fn expand_with_mode( ecx: &mut ExtCtxt<'_>, expand_span: Span, meta_item: &ast::MetaItem, mut item: Annotatable, mode: DiffMode, ) -> Vec { if cfg!(not(llvm_enzyme)) { ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span }); return vec![item]; } let dcx = ecx.sess.dcx(); // first get information about the annotable item: visibility, signature, name and generic // parameters. // these will be used to generate the differentiated version of the function let Some((vis, sig, primal, generics)) = (match &item { Annotatable::Item(iitem) => extract_item_info(iitem), Annotatable::Stmt(stmt) => match &stmt.kind { ast::StmtKind::Item(iitem) => extract_item_info(iitem), _ => None, }, Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind { ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => { Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone())) } _ => None, }, _ => None, }) else { dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); return vec![item]; }; let meta_item_vec: ThinVec = match meta_item.kind { ast::MetaItemKind::List(ref vec) => vec.clone(), _ => { dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() }); return vec![item]; } }; let has_ret = has_ret(&sig.decl.output); let sig_span = ecx.with_call_site_ctxt(sig.span); // create TokenStream from vec elemtents: // meta_item doesn't have a .tokens field let mut ts: Vec = vec![]; if meta_item_vec.len() < 1 { // At the bare minimum, we need a fnc name. dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() }); return vec![item]; } let mode_symbol = match mode { DiffMode::Forward => sym::Forward, DiffMode::Reverse => sym::Reverse, _ => unreachable!("Unsupported mode: {:?}", mode), }; // Insert mode token let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default()); ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint)); ts.insert( 1, TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone), ); // Now, if the user gave a width (vector aka batch-mode ad), then we copy it. // If it is not given, we default to 1 (scalar mode). let start_position; let kind: LitKind = LitKind::Integer; let symbol; if meta_item_vec.len() >= 2 && let Some(width) = width(&meta_item_vec[1]) { start_position = 2; symbol = Symbol::intern(&width.to_string()); } else { start_position = 1; symbol = sym::integer(1); } let l: Lit = Lit { kind, symbol, suffix: None }; let t = Token::new(TokenKind::Literal(l), Span::default()); let comma = Token::new(TokenKind::Comma, Span::default()); ts.push(TokenTree::Token(t, Spacing::Joint)); ts.push(TokenTree::Token(comma.clone(), Spacing::Alone)); for t in meta_item_vec.clone()[start_position..].iter() { meta_item_inner_to_ts(t, &mut ts); } if !has_ret { // We don't want users to provide a return activity if the function doesn't return anything. // For simplicity, we just add a dummy token to the end of the list. let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default()); ts.push(TokenTree::Token(t, Spacing::Joint)); ts.push(TokenTree::Token(comma, Spacing::Alone)); } // We remove the last, trailing comma. ts.pop(); let ts: TokenStream = TokenStream::from_iter(ts); let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret, mode); if !x.is_active() { // We encountered an error, so we return the original item. // This allows us to potentially parse other attributes. return vec![item]; } let span = ecx.with_def_site_ctxt(expand_span); let n_active: u32 = x .input_activity .iter() .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly) .count() as u32; let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); let d_body = gen_enzyme_body( ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored, &generics, ); // The first element of it is the name of the function to be generated let asdf = Box::new(ast::Fn { defaultness: ast::Defaultness::Final, sig: d_sig, ident: first_ident(&meta_item_vec[0]), generics, contract: None, body: Some(d_body), define_opaque: None, }); let mut rustc_ad_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff))); let ts2: Vec = vec![TokenTree::Token( Token::new(TokenKind::Ident(sym::never, false.into()), span), Spacing::Joint, )]; let never_arg = ast::DelimArgs { dspan: DelimSpan::from_single(span), delim: ast::token::Delimiter::Parenthesis, tokens: TokenStream::from_iter(ts2), }; let inline_item = ast::AttrItem { unsafety: ast::Safety::Default, path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)), args: ast::AttrArgs::Delimited(never_arg), tokens: None, }; let inline_never_attr = P(ast::NormalAttr { item: inline_item, tokens: None }); let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); let attr = outer_normal_attr(&rustc_ad_attr, new_id, span); let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); let inline_never = outer_normal_attr(&inline_never_attr, new_id, span); // We're avoid duplicating the attributes `#[rustc_autodiff]` and `#[inline(never)]`. fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool { match (attr, item) { (ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => { let a = &a.item.path; let b = &b.item.path; a.segments.len() == b.segments.len() && a.segments.iter().zip(b.segments.iter()).all(|(a, b)| a.ident == b.ident) } _ => false, } } // Don't add it multiple times: let orig_annotatable: Annotatable = match item { Annotatable::Item(ref mut iitem) => { if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) { iitem.attrs.push(attr); } if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) { iitem.attrs.push(inline_never.clone()); } Annotatable::Item(iitem.clone()) } Annotatable::AssocItem(ref mut assoc_item, i @ Impl { .. }) => { if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) { assoc_item.attrs.push(attr); } if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) { assoc_item.attrs.push(inline_never.clone()); } Annotatable::AssocItem(assoc_item.clone(), i) } Annotatable::Stmt(ref mut stmt) => { match stmt.kind { ast::StmtKind::Item(ref mut iitem) => { if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) { iitem.attrs.push(attr); } if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) { iitem.attrs.push(inline_never.clone()); } } _ => unreachable!("stmt kind checked previously"), }; Annotatable::Stmt(stmt.clone()) } _ => { unreachable!("annotatable kind checked previously") } }; // Now update for d_fn rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs { dspan: DelimSpan::dummy(), delim: rustc_ast::token::Delimiter::Parenthesis, tokens: ts, }); let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span); let d_annotatable = match &item { Annotatable::AssocItem(_, _) => { let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf); let d_fn = P(ast::AssocItem { attrs: thin_vec![d_attr, inline_never], id: ast::DUMMY_NODE_ID, span, vis, kind: assoc_item, tokens: None, }); Annotatable::AssocItem(d_fn, Impl { of_trait: false }) } Annotatable::Item(_) => { let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf)); d_fn.vis = vis; Annotatable::Item(d_fn) } Annotatable::Stmt(_) => { let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf)); d_fn.vis = vis; Annotatable::Stmt(P(ast::Stmt { id: ast::DUMMY_NODE_ID, kind: ast::StmtKind::Item(d_fn), span, })) } _ => { unreachable!("item kind checked previously") } }; return vec![orig_annotatable, d_annotatable]; } // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be // mutable references or ptrs, because Enzyme will write into them. fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty { let mut ty = ty.clone(); match ty.kind { TyKind::Ptr(ref mut mut_ty) => { mut_ty.mutbl = ast::Mutability::Mut; } TyKind::Ref(_, ref mut mut_ty) => { mut_ty.mutbl = ast::Mutability::Mut; } _ => { panic!("unsupported type: {:?}", ty); } } ty } // Will generate a body of the type: // ``` // { // unsafe { // asm!("NOP"); // } // ::core::hint::black_box(primal(args)); // ::core::hint::black_box((args, ret)); // // } // ``` fn init_body_helper( ecx: &ExtCtxt<'_>, span: Span, primal: Ident, new_names: &[String], sig_span: Span, new_decl_span: Span, idents: &[Ident], errored: bool, generics: &Generics, ) -> (P, P, P, P) { let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]); let noop = ast::InlineAsm { asm_macro: ast::AsmMacro::Asm, template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())], template_strs: Box::new([]), operands: vec![], clobber_abis: vec![], options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM, line_spans: vec![], }; let noop_expr = ecx.expr_asm(span, P(noop)); let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated); let unsf_block = ast::Block { stmts: thin_vec![ecx.stmt_semi(noop_expr)], id: ast::DUMMY_NODE_ID, tokens: None, rules: unsf, span, }; let unsf_expr = ecx.expr_block(P(unsf_block)); let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path)); let primal_call = gen_primal_call(ecx, span, primal, idents, generics); let black_box_primal_call = ecx.expr_call( new_decl_span, blackbox_call_expr.clone(), thin_vec![primal_call.clone()], ); let tup_args = new_names .iter() .map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg)))) .collect(); let black_box_remaining_args = ecx.expr_call( sig_span, blackbox_call_expr.clone(), thin_vec![ecx.expr_tuple(sig_span, tup_args)], ); let mut body = ecx.block(span, ThinVec::new()); body.stmts.push(ecx.stmt_semi(unsf_expr)); // This uses primal args which won't be available if we errored before if !errored { body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone())); } body.stmts.push(ecx.stmt_semi(black_box_remaining_args)); (body, primal_call, black_box_primal_call, blackbox_call_expr) } /// We only want this function to type-check, since we will replace the body /// later on llvm level. Using `loop {}` does not cover all return types anymore, /// so instead we manually build something that should pass the type checker. /// We also add a inline_asm line, as one more barrier for rustc to prevent inlining /// or const propagation. inline_asm will also triggers an Enzyme crash if due to another /// bug would ever try to accidentally differentiate this placeholder function body. /// Finally, we also add back_box usages of all input arguments, to prevent rustc /// from optimizing any arguments away. fn gen_enzyme_body( ecx: &ExtCtxt<'_>, x: &AutoDiffAttrs, n_active: u32, sig: &ast::FnSig, d_sig: &ast::FnSig, primal: Ident, new_names: &[String], span: Span, sig_span: Span, idents: Vec, errored: bool, generics: &Generics, ) -> P { let new_decl_span = d_sig.span; // Just adding some default inline-asm and black_box usages to prevent early inlining // and optimizations which alter the function signature. // // The bb_primal_call is the black_box call of the primal function. We keep it around, // since it has the convenient property of returning the type of the primal function, // Remember, we only care to match types here. // No matter which return we pick, we always wrap it into a std::hint::black_box call, // to prevent rustc from propagating it into the caller. let (mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper( ecx, span, primal, new_names, sig_span, new_decl_span, &idents, errored, generics, ); if !has_ret(&d_sig.decl.output) { // there is no return type that we have to match, () works fine. return body; } // Everything from here onwards just tries to fulfil the return type. Fun! // having an active-only return means we'll drop the original return type. // So that can be treated identical to not having one in the first place. let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret(); if primal_ret && n_active == 0 && x.mode.is_rev() { // We only have the primal ret. body.stmts.push(ecx.stmt_expr(bb_primal_call)); return body; } if !primal_ret && n_active == 1 { // Again no tuple return, so return default float val. let ty = match d_sig.decl.output { FnRetTy::Ty(ref ty) => ty.clone(), FnRetTy::Default(span) => { panic!("Did not expect Default ret ty: {:?}", span); } }; let arg = ty.kind.is_simple_path().unwrap(); let tmp = ecx.def_site_path(&[arg, kw::Default]); let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); body.stmts.push(ecx.stmt_expr(default_call_expr)); return body; } let mut exprs: P = primal_call; let d_ret_ty = match d_sig.decl.output { FnRetTy::Ty(ref ty) => ty.clone(), FnRetTy::Default(span) => { panic!("Did not expect Default ret ty: {:?}", span); } }; if x.mode.is_fwd() { // Fwd mode is easy. If the return activity is Const, we support arbitrary types. // Otherwise, we only support a scalar, a pair of scalars, or an array of scalars. // We checked that (on a best-effort base) in the preceding gen_enzyme_decl function. // In all three cases, we can return `std::hint::black_box(::default())`. if x.ret_activity == DiffActivity::Const { // Here we call the primal function, since our dummy function has the same return // type due to the Const return activity. exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]); } else { let q = QSelf { ty: d_ret_ty, path_span: span, position: 0 }; let y = ExprKind::Path( Some(P(q)), ecx.path_ident(span, Ident::with_dummy_span(kw::Default)), ); let default_call_expr = ecx.expr(span, y); let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![default_call_expr]); } } else if x.mode.is_rev() { if x.width == 1 { // We either have `-> ArbitraryType` or `-> (ArbitraryType, repeated_float_scalars)`. match d_ret_ty.kind { TyKind::Tup(ref args) => { // We have a tuple return type. We need to create a tuple of the same size // and fill it with default values. let mut exprs2 = thin_vec![exprs]; for arg in args.iter().skip(1) { let arg = arg.kind.is_simple_path().unwrap(); let tmp = ecx.def_site_path(&[arg, kw::Default]); let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); exprs2.push(default_call_expr); } exprs = ecx.expr_tuple(new_decl_span, exprs2); } _ => { // Interestingly, even the `-> ArbitraryType` case // ends up getting matched and handled correctly above, // so we don't have to handle any other case for now. panic!("Unsupported return type: {:?}", d_ret_ty); } } } exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]); } else { unreachable!("Unsupported mode: {:?}", x.mode); } body.stmts.push(ecx.stmt_expr(exprs)); body } fn gen_primal_call( ecx: &ExtCtxt<'_>, span: Span, primal: Ident, idents: &[Ident], generics: &Generics, ) -> P { let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower; if has_self { let args: ThinVec<_> = idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); let self_expr = ecx.expr_self(span); ecx.expr_method_call(span, self_expr, primal, args) } else { let args: ThinVec<_> = idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); let mut primal_path = ecx.path_ident(span, primal); let is_generic = !generics.params.is_empty(); match (is_generic, primal_path.segments.last_mut()) { (true, Some(function_path)) => { let primal_generic_types = generics .params .iter() .filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. })); let generated_generic_types = primal_generic_types .map(|type_param| { let generic_param = TyKind::Path( None, ast::Path { span, segments: thin_vec![ast::PathSegment { ident: type_param.ident, args: None, id: ast::DUMMY_NODE_ID, }], tokens: None, }, ); ast::AngleBracketedArg::Arg(ast::GenericArg::Type(P(ast::Ty { id: type_param.id, span, kind: generic_param, tokens: None, }))) }) .collect(); function_path.args = Some(P(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs { span, args: generated_generic_types, }))); } _ => {} } let primal_call_expr = ecx.expr_path(primal_path); ecx.expr_call(span, primal_call_expr, args) } } // Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must // be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer. // Active arguments must be scalars. Their shadow argument is added to the return type (and will be // zero-initialized by Enzyme). // Each argument of the primal function (and the return type if existing) must be annotated with an // activity. // // Error handling: If the user provides an invalid configuration (incorrect numbers, types, or // both), we emit an error and return the original signature. This allows us to continue parsing. // FIXME(Sa4dUs): make individual activities' span available so errors // can point to only the activity instead of the entire attribute fn gen_enzyme_decl( ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span: Span, ) -> (ast::FnSig, Vec, Vec, bool) { let dcx = ecx.sess.dcx(); let has_ret = has_ret(&sig.decl.output); let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 }; let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 }; if sig_args != num_activities { dcx.emit_err(errors::AutoDiffInvalidNumberActivities { span, expected: sig_args, found: num_activities, }); // This is not the right signature, but we can continue parsing. return (sig.clone(), vec![], vec![], true); } assert!(sig.decl.inputs.len() == x.input_activity.len()); assert!(has_ret == x.has_ret_activity()); let mut d_decl = sig.decl.clone(); let mut d_inputs = Vec::new(); let mut new_inputs = Vec::new(); let mut idents = Vec::new(); let mut act_ret = ThinVec::new(); // We have two loops, a first one just to check the activities and types and possibly report // multiple errors in one compilation session. let mut errors = false; for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) { if !valid_input_activity(x.mode, *activity) { dcx.emit_err(errors::AutoDiffInvalidApplicationModeAct { span, mode: x.mode.to_string(), act: activity.to_string(), }); errors = true; } if !valid_ty_for_activity(&arg.ty, *activity) { dcx.emit_err(errors::AutoDiffInvalidTypeForActivity { span: arg.ty.span, act: activity.to_string(), }); errors = true; } } if has_ret && !valid_ret_activity(x.mode, x.ret_activity) { dcx.emit_err(errors::AutoDiffInvalidRetAct { span, mode: x.mode.to_string(), act: x.ret_activity.to_string(), }); // We don't set `errors = true` to avoid annoying type errors relative // to the expanded macro type signature } if errors { // This is not the right signature, but we can continue parsing. return (sig.clone(), new_inputs, idents, true); } let unsafe_activities = x .input_activity .iter() .any(|&act| matches!(act, DiffActivity::DuplicatedOnly | DiffActivity::DualOnly)); for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) { d_inputs.push(arg.clone()); match activity { DiffActivity::Active => { act_ret.push(arg.ty.clone()); // if width =/= 1, then push [arg.ty; width] to act_ret } DiffActivity::ActiveOnly => { // We will add the active scalar to the return type. // This is handled later. } DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => { for i in 0..x.width { let mut shadow_arg = arg.clone(); // We += into the shadow in reverse mode. shadow_arg.ty = P(assure_mut_ref(&arg.ty)); let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind { ident.name } else { debug!("{:#?}", &shadow_arg.pat); panic!("not an ident?"); }; let name: String = format!("d{}_{}", old_name, i); new_inputs.push(name.clone()); let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span); shadow_arg.pat = P(ast::Pat { id: ast::DUMMY_NODE_ID, kind: PatKind::Ident(BindingMode::NONE, ident, None), span: shadow_arg.pat.span, tokens: shadow_arg.pat.tokens.clone(), }); d_inputs.push(shadow_arg.clone()); } } DiffActivity::Dual | DiffActivity::DualOnly | DiffActivity::Dualv | DiffActivity::DualvOnly => { // the *v variants get lowered to enzyme_dupv and enzyme_dupnoneedv, which cause // Enzyme to not expect N arguments, but one argument (which is instead larger). let iterations = if matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) { 1 } else { x.width }; for i in 0..iterations { let mut shadow_arg = arg.clone(); let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind { ident.name } else { debug!("{:#?}", &shadow_arg.pat); panic!("not an ident?"); }; let name: String = format!("b{}_{}", old_name, i); new_inputs.push(name.clone()); let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span); shadow_arg.pat = P(ast::Pat { id: ast::DUMMY_NODE_ID, kind: PatKind::Ident(BindingMode::NONE, ident, None), span: shadow_arg.pat.span, tokens: shadow_arg.pat.tokens.clone(), }); d_inputs.push(shadow_arg.clone()); } } DiffActivity::Const => { // Nothing to do here. } DiffActivity::None | DiffActivity::FakeActivitySize(_) => { panic!("Should not happen"); } } if let PatKind::Ident(_, ident, _) = arg.pat.kind { idents.push(ident.clone()); } else { panic!("not an ident?"); } } let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly; if active_only_ret { assert!(x.mode.is_rev()); } // If we return a scalar in the primal and the scalar is active, // then add it as last arg to the inputs. if x.mode.is_rev() { match x.ret_activity { DiffActivity::Active | DiffActivity::ActiveOnly => { let ty = match d_decl.output { FnRetTy::Ty(ref ty) => ty.clone(), FnRetTy::Default(span) => { panic!("Did not expect Default ret ty: {:?}", span); } }; let name = "dret".to_string(); let ident = Ident::from_str_and_span(&name, ty.span); let shadow_arg = ast::Param { attrs: ThinVec::new(), ty: ty.clone(), pat: P(ast::Pat { id: ast::DUMMY_NODE_ID, kind: PatKind::Ident(BindingMode::NONE, ident, None), span: ty.span, tokens: None, }), id: ast::DUMMY_NODE_ID, span: ty.span, is_placeholder: false, }; d_inputs.push(shadow_arg); new_inputs.push(name); } _ => {} } } d_decl.inputs = d_inputs.into(); if x.mode.is_fwd() { let ty = match d_decl.output { FnRetTy::Ty(ref ty) => ty.clone(), FnRetTy::Default(span) => { // We want to return std::hint::black_box(()). let kind = TyKind::Tup(ThinVec::new()); let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None }); d_decl.output = FnRetTy::Ty(ty.clone()); assert!(matches!(x.ret_activity, DiffActivity::None)); // this won't be used below, so any type would be fine. ty } }; if matches!(x.ret_activity, DiffActivity::Dual | DiffActivity::Dualv) { let kind = if x.width == 1 || matches!(x.ret_activity, DiffActivity::Dualv) { // Dual can only be used for f32/f64 ret. // In that case we return now a tuple with two floats. TyKind::Tup(thin_vec![ty.clone(), ty.clone()]) } else { // We have to return [T; width+1], +1 for the primal return. let anon_const = rustc_ast::AnonConst { id: ast::DUMMY_NODE_ID, value: ecx.expr_usize(span, 1 + x.width as usize), }; TyKind::Array(ty.clone(), anon_const) }; let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None }); d_decl.output = FnRetTy::Ty(ty); } if matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) { // No need to change the return type, // we will just return the shadow in place of the primal return. // However, if we have a width > 1, then we don't return -> T, but -> [T; width] if x.width > 1 { let anon_const = rustc_ast::AnonConst { id: ast::DUMMY_NODE_ID, value: ecx.expr_usize(span, x.width as usize), }; let kind = TyKind::Array(ty.clone(), anon_const); let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None }); d_decl.output = FnRetTy::Ty(ty); } } } // If we use ActiveOnly, drop the original return value. d_decl.output = if active_only_ret { FnRetTy::Default(span) } else { d_decl.output.clone() }; trace!("act_ret: {:?}", act_ret); // If we have an active input scalar, add it's gradient to the // return type. This might require changing the return type to a // tuple. if act_ret.len() > 0 { let ret_ty = match d_decl.output { FnRetTy::Ty(ref ty) => { if !active_only_ret { act_ret.insert(0, ty.clone()); } let kind = TyKind::Tup(act_ret); P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None }) } FnRetTy::Default(span) => { if act_ret.len() == 1 { act_ret[0].clone() } else { let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect()); P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None }) } } }; d_decl.output = FnRetTy::Ty(ret_ty); } let mut d_header = sig.header.clone(); if unsafe_activities { d_header.safety = rustc_ast::Safety::Unsafe(span); } let d_sig = FnSig { header: d_header, decl: d_decl, span }; trace!("Generated signature: {:?}", d_sig); (d_sig, new_inputs, idents, false) } } pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};