2024-10-11 19:13:31 +02:00
|
|
|
//! This module contains the implementation of the `#[autodiff]` attribute.
|
|
|
|
|
//! Currently our linter isn't smart enough to see that each import is used in one of the two
|
|
|
|
|
//! configs (autodiff enabled or disabled), so we have to add cfg's to each import.
|
|
|
|
|
//! FIXME(ZuseZ4): Remove this once we have a smarter linter.
|
|
|
|
|
|
|
|
|
|
mod llvm_enzyme {
|
|
|
|
|
use std::str::FromStr;
|
|
|
|
|
use std::string::String;
|
|
|
|
|
|
|
|
|
|
use rustc_ast::expand::autodiff_attrs::{
|
2025-03-07 17:37:50 +01:00
|
|
|
AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
|
|
|
|
|
valid_ty_for_activity,
|
2024-10-11 19:13:31 +02:00
|
|
|
};
|
|
|
|
|
use rustc_ast::ptr::P;
|
2025-04-03 17:19:11 -04:00
|
|
|
use rustc_ast::token::{Lit, LitKind, Token, TokenKind};
|
2024-10-11 19:13:31 +02:00
|
|
|
use rustc_ast::tokenstream::*;
|
|
|
|
|
use rustc_ast::visit::AssocCtxt::*;
|
|
|
|
|
use rustc_ast::{
|
2025-04-03 17:19:11 -04:00
|
|
|
self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
|
2025-04-03 22:47:30 +02:00
|
|
|
MetaItemInner, PatKind, QSelf, TyKind, Visibility,
|
2024-10-11 19:13:31 +02:00
|
|
|
};
|
|
|
|
|
use rustc_expand::base::{Annotatable, ExtCtxt};
|
2024-12-13 10:29:23 +11:00
|
|
|
use rustc_span::{Ident, Span, Symbol, kw, sym};
|
2024-10-11 19:13:31 +02:00
|
|
|
use thin_vec::{ThinVec, thin_vec};
|
|
|
|
|
use tracing::{debug, trace};
|
|
|
|
|
|
|
|
|
|
use crate::errors;
|
|
|
|
|
|
2025-03-17 17:23:35 -04:00
|
|
|
pub(crate) fn outer_normal_attr(
|
|
|
|
|
kind: &P<rustc_ast::NormalAttr>,
|
|
|
|
|
id: rustc_ast::AttrId,
|
|
|
|
|
span: Span,
|
|
|
|
|
) -> rustc_ast::Attribute {
|
|
|
|
|
let style = rustc_ast::AttrStyle::Outer;
|
|
|
|
|
let kind = rustc_ast::AttrKind::Normal(kind.clone());
|
|
|
|
|
rustc_ast::Attribute { kind, id, style, span }
|
|
|
|
|
}
|
|
|
|
|
|
2024-10-11 19:13:31 +02:00
|
|
|
// If we have a default `()` return type or explicitley `()` return type,
|
|
|
|
|
// then we often can skip doing some work.
|
|
|
|
|
fn has_ret(ty: &FnRetTy) -> bool {
|
|
|
|
|
match ty {
|
|
|
|
|
FnRetTy::Ty(ty) => !ty.kind.is_unit(),
|
|
|
|
|
FnRetTy::Default(_) => false,
|
|
|
|
|
}
|
|
|
|
|
}
|
2024-12-13 10:29:23 +11:00
|
|
|
fn first_ident(x: &MetaItemInner) -> rustc_span::Ident {
|
2025-04-03 17:19:11 -04:00
|
|
|
if let Some(l) = x.lit() {
|
|
|
|
|
match l.kind {
|
|
|
|
|
ast::LitKind::Int(val, _) => {
|
|
|
|
|
// get an Ident from a lit
|
|
|
|
|
return rustc_span::Ident::from_str(val.get().to_string().as_str());
|
|
|
|
|
}
|
|
|
|
|
_ => {}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2024-10-11 19:13:31 +02:00
|
|
|
let segments = &x.meta_item().unwrap().path.segments;
|
|
|
|
|
assert!(segments.len() == 1);
|
|
|
|
|
segments[0].ident
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn name(x: &MetaItemInner) -> String {
|
|
|
|
|
first_ident(x).name.to_string()
|
|
|
|
|
}
|
|
|
|
|
|
2025-04-03 17:19:11 -04:00
|
|
|
fn width(x: &MetaItemInner) -> Option<u128> {
|
|
|
|
|
let lit = x.lit()?;
|
|
|
|
|
match lit.kind {
|
|
|
|
|
ast::LitKind::Int(x, _) => Some(x.get()),
|
|
|
|
|
_ => return None,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2025-04-03 22:47:30 +02:00
|
|
|
// Get information about the function the macro is applied to
|
2025-04-19 19:17:22 +02:00
|
|
|
fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> {
|
2025-04-03 22:47:30 +02:00
|
|
|
match &iitem.kind {
|
2025-04-19 19:17:22 +02:00
|
|
|
ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
|
|
|
|
|
Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
|
2025-04-03 22:47:30 +02:00
|
|
|
}
|
|
|
|
|
_ => None,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2024-10-11 19:13:31 +02:00
|
|
|
pub(crate) fn from_ast(
|
|
|
|
|
ecx: &mut ExtCtxt<'_>,
|
|
|
|
|
meta_item: &ThinVec<MetaItemInner>,
|
|
|
|
|
has_ret: bool,
|
|
|
|
|
) -> AutoDiffAttrs {
|
|
|
|
|
let dcx = ecx.sess.dcx();
|
2025-04-03 17:19:11 -04:00
|
|
|
|
|
|
|
|
// Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
|
|
|
|
|
// If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
|
2025-05-06 09:19:33 +02:00
|
|
|
let mut first_activity = 1;
|
2025-04-03 17:19:11 -04:00
|
|
|
|
2025-05-06 09:19:33 +02:00
|
|
|
let width = if let [_, x, ..] = &meta_item[..]
|
2025-04-03 17:19:11 -04:00
|
|
|
&& let Some(x) = width(x)
|
|
|
|
|
{
|
2025-05-06 09:19:33 +02:00
|
|
|
first_activity = 2;
|
2025-04-03 17:19:11 -04:00
|
|
|
match x.try_into() {
|
|
|
|
|
Ok(x) => x,
|
|
|
|
|
Err(_) => {
|
|
|
|
|
dcx.emit_err(errors::AutoDiffInvalidWidth {
|
2025-05-06 09:19:33 +02:00
|
|
|
span: meta_item[1].span(),
|
2025-04-03 17:19:11 -04:00
|
|
|
width: x,
|
|
|
|
|
});
|
|
|
|
|
return AutoDiffAttrs::error();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
1
|
|
|
|
|
};
|
|
|
|
|
|
2024-10-11 19:13:31 +02:00
|
|
|
let mut activities: Vec<DiffActivity> = vec![];
|
|
|
|
|
let mut errors = false;
|
2025-04-03 17:19:11 -04:00
|
|
|
for x in &meta_item[first_activity..] {
|
2024-10-11 19:13:31 +02:00
|
|
|
let activity_str = name(&x);
|
|
|
|
|
let res = DiffActivity::from_str(&activity_str);
|
|
|
|
|
match res {
|
|
|
|
|
Ok(x) => activities.push(x),
|
|
|
|
|
Err(_) => {
|
|
|
|
|
dcx.emit_err(errors::AutoDiffUnknownActivity {
|
|
|
|
|
span: x.span(),
|
|
|
|
|
act: activity_str,
|
|
|
|
|
});
|
|
|
|
|
errors = true;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
if errors {
|
|
|
|
|
return AutoDiffAttrs::error();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// If a return type exist, we need to split the last activity,
|
|
|
|
|
// otherwise we return None as placeholder.
|
|
|
|
|
let (ret_activity, input_activity) = if has_ret {
|
|
|
|
|
let Some((last, rest)) = activities.split_last() else {
|
|
|
|
|
unreachable!(
|
|
|
|
|
"should not be reachable because we counted the number of activities previously"
|
|
|
|
|
);
|
|
|
|
|
};
|
|
|
|
|
(last, rest)
|
|
|
|
|
} else {
|
|
|
|
|
(&DiffActivity::None, activities.as_slice())
|
|
|
|
|
};
|
|
|
|
|
|
2025-04-03 17:19:11 -04:00
|
|
|
AutoDiffAttrs {
|
2025-05-06 09:19:33 +02:00
|
|
|
mode: DiffMode::Error,
|
2025-04-03 17:19:11 -04:00
|
|
|
width,
|
|
|
|
|
ret_activity: *ret_activity,
|
|
|
|
|
input_activity: input_activity.to_vec(),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn meta_item_inner_to_ts(t: &MetaItemInner, ts: &mut Vec<TokenTree>) {
|
|
|
|
|
let comma: Token = Token::new(TokenKind::Comma, Span::default());
|
|
|
|
|
let val = first_ident(t);
|
|
|
|
|
let t = Token::from_ast_ident(val);
|
|
|
|
|
ts.push(TokenTree::Token(t, Spacing::Joint));
|
|
|
|
|
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
|
2024-10-11 19:13:31 +02:00
|
|
|
}
|
|
|
|
|
|
2025-05-06 09:19:33 +02:00
|
|
|
pub(crate) fn expand_forward(
|
|
|
|
|
ecx: &mut ExtCtxt<'_>,
|
|
|
|
|
expand_span: Span,
|
|
|
|
|
meta_item: &ast::MetaItem,
|
|
|
|
|
item: Annotatable,
|
|
|
|
|
) -> Vec<Annotatable> {
|
|
|
|
|
expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub(crate) fn expand_reverse(
|
|
|
|
|
ecx: &mut ExtCtxt<'_>,
|
|
|
|
|
expand_span: Span,
|
|
|
|
|
meta_item: &ast::MetaItem,
|
|
|
|
|
item: Annotatable,
|
|
|
|
|
) -> Vec<Annotatable> {
|
|
|
|
|
expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse)
|
|
|
|
|
}
|
|
|
|
|
|
2024-10-11 19:13:31 +02:00
|
|
|
/// We expand the autodiff macro to generate a new placeholder function which passes
|
|
|
|
|
/// type-checking and can be called by users. The function body of the placeholder function will
|
|
|
|
|
/// later be replaced on LLVM-IR level, so the design of the body is less important and for now
|
|
|
|
|
/// should just prevent early inlining and optimizations which alter the function signature.
|
|
|
|
|
/// The exact signature of the generated function depends on the configuration provided by the
|
|
|
|
|
/// user, but here is an example:
|
|
|
|
|
///
|
|
|
|
|
/// ```
|
|
|
|
|
/// #[autodiff(cos_box, Reverse, Duplicated, Active)]
|
|
|
|
|
/// fn sin(x: &Box<f32>) -> f32 {
|
|
|
|
|
/// f32::sin(**x)
|
|
|
|
|
/// }
|
|
|
|
|
/// ```
|
|
|
|
|
/// which becomes expanded to:
|
|
|
|
|
/// ```
|
|
|
|
|
/// #[rustc_autodiff]
|
|
|
|
|
/// #[inline(never)]
|
|
|
|
|
/// fn sin(x: &Box<f32>) -> f32 {
|
|
|
|
|
/// f32::sin(**x)
|
|
|
|
|
/// }
|
|
|
|
|
/// #[rustc_autodiff(Reverse, Duplicated, Active)]
|
|
|
|
|
/// #[inline(never)]
|
|
|
|
|
/// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
|
|
|
|
|
/// unsafe {
|
|
|
|
|
/// asm!("NOP");
|
|
|
|
|
/// };
|
|
|
|
|
/// ::core::hint::black_box(sin(x));
|
|
|
|
|
/// ::core::hint::black_box((dx, dret));
|
|
|
|
|
/// ::core::hint::black_box(sin(x))
|
|
|
|
|
/// }
|
|
|
|
|
/// ```
|
|
|
|
|
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
|
|
|
|
|
/// in CI.
|
2025-05-06 09:19:33 +02:00
|
|
|
pub(crate) fn expand_with_mode(
|
2024-10-11 19:13:31 +02:00
|
|
|
ecx: &mut ExtCtxt<'_>,
|
|
|
|
|
expand_span: Span,
|
|
|
|
|
meta_item: &ast::MetaItem,
|
|
|
|
|
mut item: Annotatable,
|
2025-05-06 09:19:33 +02:00
|
|
|
mode: DiffMode,
|
2024-10-11 19:13:31 +02:00
|
|
|
) -> Vec<Annotatable> {
|
2025-02-27 19:32:30 +05:30
|
|
|
if cfg!(not(llvm_enzyme)) {
|
|
|
|
|
ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
|
|
|
|
|
return vec![item];
|
|
|
|
|
}
|
2024-10-11 19:13:31 +02:00
|
|
|
let dcx = ecx.sess.dcx();
|
2025-03-09 22:55:07 +01:00
|
|
|
|
2025-04-19 19:17:22 +02:00
|
|
|
// first get information about the annotable item: visibility, signature, name and generic
|
|
|
|
|
// parameters.
|
|
|
|
|
// these will be used to generate the differentiated version of the function
|
|
|
|
|
let Some((vis, sig, primal, generics)) = (match &item {
|
2025-04-03 22:47:30 +02:00
|
|
|
Annotatable::Item(iitem) => extract_item_info(iitem),
|
|
|
|
|
Annotatable::Stmt(stmt) => match &stmt.kind {
|
|
|
|
|
ast::StmtKind::Item(iitem) => extract_item_info(iitem),
|
|
|
|
|
_ => None,
|
|
|
|
|
},
|
2025-04-21 13:22:56 +05:30
|
|
|
Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind {
|
2025-04-19 19:17:22 +02:00
|
|
|
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
|
|
|
|
|
Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
|
2025-04-03 22:47:30 +02:00
|
|
|
}
|
2025-04-21 13:22:56 +05:30
|
|
|
_ => None,
|
|
|
|
|
},
|
2025-04-03 22:47:30 +02:00
|
|
|
_ => None,
|
|
|
|
|
}) else {
|
|
|
|
|
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
|
|
|
|
|
return vec![item];
|
2024-10-11 19:13:31 +02:00
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
|
|
|
|
|
ast::MetaItemKind::List(ref vec) => vec.clone(),
|
|
|
|
|
_ => {
|
2025-04-08 21:54:34 -04:00
|
|
|
dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
|
2024-10-11 19:13:31 +02:00
|
|
|
return vec![item];
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let has_ret = has_ret(&sig.decl.output);
|
|
|
|
|
let sig_span = ecx.with_call_site_ctxt(sig.span);
|
|
|
|
|
|
|
|
|
|
// create TokenStream from vec elemtents:
|
|
|
|
|
// meta_item doesn't have a .tokens field
|
|
|
|
|
let mut ts: Vec<TokenTree> = vec![];
|
|
|
|
|
if meta_item_vec.len() < 2 {
|
|
|
|
|
// At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
|
|
|
|
|
// input and output args.
|
|
|
|
|
dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
|
|
|
|
|
return vec![item];
|
2025-04-03 17:19:11 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
meta_item_inner_to_ts(&meta_item_vec[1], &mut ts);
|
|
|
|
|
|
|
|
|
|
// Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
|
|
|
|
|
// If it is not given, we default to 1 (scalar mode).
|
|
|
|
|
let start_position;
|
|
|
|
|
let kind: LitKind = LitKind::Integer;
|
|
|
|
|
let symbol;
|
|
|
|
|
if meta_item_vec.len() >= 3
|
|
|
|
|
&& let Some(width) = width(&meta_item_vec[2])
|
|
|
|
|
{
|
|
|
|
|
start_position = 3;
|
|
|
|
|
symbol = Symbol::intern(&width.to_string());
|
2024-10-11 19:13:31 +02:00
|
|
|
} else {
|
2025-04-03 17:19:11 -04:00
|
|
|
start_position = 2;
|
|
|
|
|
symbol = sym::integer(1);
|
2024-10-11 19:13:31 +02:00
|
|
|
}
|
2025-04-03 17:19:11 -04:00
|
|
|
let l: Lit = Lit { kind, symbol, suffix: None };
|
|
|
|
|
let t = Token::new(TokenKind::Literal(l), Span::default());
|
|
|
|
|
let comma = Token::new(TokenKind::Comma, Span::default());
|
|
|
|
|
ts.push(TokenTree::Token(t, Spacing::Joint));
|
|
|
|
|
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
|
|
|
|
|
|
|
|
|
|
for t in meta_item_vec.clone()[start_position..].iter() {
|
|
|
|
|
meta_item_inner_to_ts(t, &mut ts);
|
|
|
|
|
}
|
|
|
|
|
|
2024-10-11 19:13:31 +02:00
|
|
|
if !has_ret {
|
|
|
|
|
// We don't want users to provide a return activity if the function doesn't return anything.
|
|
|
|
|
// For simplicity, we just add a dummy token to the end of the list.
|
|
|
|
|
let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
|
|
|
|
|
ts.push(TokenTree::Token(t, Spacing::Joint));
|
2025-04-03 17:19:11 -04:00
|
|
|
ts.push(TokenTree::Token(comma, Spacing::Alone));
|
2024-10-11 19:13:31 +02:00
|
|
|
}
|
2025-04-03 17:19:11 -04:00
|
|
|
// We remove the last, trailing comma.
|
|
|
|
|
ts.pop();
|
2024-10-11 19:13:31 +02:00
|
|
|
let ts: TokenStream = TokenStream::from_iter(ts);
|
|
|
|
|
|
2025-05-06 09:19:33 +02:00
|
|
|
let mut x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
|
|
|
|
|
x.mode = mode;
|
2024-10-11 19:13:31 +02:00
|
|
|
if !x.is_active() {
|
|
|
|
|
// We encountered an error, so we return the original item.
|
|
|
|
|
// This allows us to potentially parse other attributes.
|
|
|
|
|
return vec![item];
|
|
|
|
|
}
|
|
|
|
|
let span = ecx.with_def_site_ctxt(expand_span);
|
|
|
|
|
|
|
|
|
|
let n_active: u32 = x
|
|
|
|
|
.input_activity
|
|
|
|
|
.iter()
|
|
|
|
|
.filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
|
|
|
|
|
.count() as u32;
|
|
|
|
|
let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
|
|
|
|
|
let d_body = gen_enzyme_body(
|
2025-03-18 02:47:37 -04:00
|
|
|
ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
|
2025-04-20 01:10:50 +02:00
|
|
|
&generics,
|
2024-10-11 19:13:31 +02:00
|
|
|
);
|
|
|
|
|
|
|
|
|
|
// The first element of it is the name of the function to be generated
|
|
|
|
|
let asdf = Box::new(ast::Fn {
|
|
|
|
|
defaultness: ast::Defaultness::Final,
|
|
|
|
|
sig: d_sig,
|
Move `ast::Item::ident` into `ast::ItemKind`.
`ast::Item` has an `ident` field.
- It's always non-empty for these item kinds: `ExternCrate`, `Static`,
`Const`, `Fn`, `Mod`, `TyAlias`, `Enum`, `Struct`, `Union`,
`Trait`, `TraitAlias`, `MacroDef`, `Delegation`.
- It's always empty for these item kinds: `Use`, `ForeignMod`,
`GlobalAsm`, `Impl`, `MacCall`, `DelegationMac`.
There is a similar story for `AssocItemKind` and `ForeignItemKind`.
Some sites that handle items check for an empty ident, some don't. This
is a very C-like way of doing things, but this is Rust, we have sum
types, we can do this properly and never forget to check for the
exceptional case and never YOLO possibly empty identifiers (or possibly
dummy spans) around and hope that things will work out.
The commit is large but it's mostly obvious plumbing work. Some notable
things.
- `ast::Item` got 8 bytes bigger. This could be avoided by boxing the
fields within some of the `ast::ItemKind` variants (specifically:
`Struct`, `Union`, `Enum`). I might do that in a follow-up; this
commit is big enough already.
- For the visitors: `FnKind` no longer needs an `ident` field because
the `Fn` within how has one.
- In the parser, the `ItemInfo` typedef is no longer needed. It was used
in various places to return an `Ident` alongside an `ItemKind`, but
now the `Ident` (if present) is within the `ItemKind`.
- In a few places I renamed identifier variables called `name` (or
`foo_name`) as `ident` (or `foo_ident`), to better match the type, and
because `name` is normally used for `Symbol`s. It's confusing to see
something like `foo_name.name`.
2025-03-21 09:47:43 +11:00
|
|
|
ident: first_ident(&meta_item_vec[0]),
|
2025-04-19 19:17:22 +02:00
|
|
|
generics,
|
2025-02-21 21:45:29 -05:00
|
|
|
contract: None,
|
2024-10-11 19:13:31 +02:00
|
|
|
body: Some(d_body),
|
2024-07-26 10:04:02 +00:00
|
|
|
define_opaque: None,
|
2024-10-11 19:13:31 +02:00
|
|
|
});
|
|
|
|
|
let mut rustc_ad_attr =
|
|
|
|
|
P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));
|
|
|
|
|
|
|
|
|
|
let ts2: Vec<TokenTree> = vec![TokenTree::Token(
|
|
|
|
|
Token::new(TokenKind::Ident(sym::never, false.into()), span),
|
|
|
|
|
Spacing::Joint,
|
|
|
|
|
)];
|
|
|
|
|
let never_arg = ast::DelimArgs {
|
2024-05-17 17:31:34 +10:00
|
|
|
dspan: DelimSpan::from_single(span),
|
2024-10-11 19:13:31 +02:00
|
|
|
delim: ast::token::Delimiter::Parenthesis,
|
2024-05-17 17:31:34 +10:00
|
|
|
tokens: TokenStream::from_iter(ts2),
|
2024-10-11 19:13:31 +02:00
|
|
|
};
|
|
|
|
|
let inline_item = ast::AttrItem {
|
|
|
|
|
unsafety: ast::Safety::Default,
|
|
|
|
|
path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)),
|
|
|
|
|
args: ast::AttrArgs::Delimited(never_arg),
|
|
|
|
|
tokens: None,
|
|
|
|
|
};
|
|
|
|
|
let inline_never_attr = P(ast::NormalAttr { item: inline_item, tokens: None });
|
|
|
|
|
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
|
2025-03-17 17:23:35 -04:00
|
|
|
let attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
|
2024-10-11 19:13:31 +02:00
|
|
|
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
|
2025-03-17 17:23:35 -04:00
|
|
|
let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);
|
2024-10-11 19:13:31 +02:00
|
|
|
|
2025-03-17 17:06:26 -04:00
|
|
|
// We're avoid duplicating the attributes `#[rustc_autodiff]` and `#[inline(never)]`.
|
|
|
|
|
fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
|
|
|
|
|
match (attr, item) {
|
|
|
|
|
(ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => {
|
|
|
|
|
let a = &a.item.path;
|
|
|
|
|
let b = &b.item.path;
|
|
|
|
|
a.segments.len() == b.segments.len()
|
|
|
|
|
&& a.segments.iter().zip(b.segments.iter()).all(|(a, b)| a.ident == b.ident)
|
|
|
|
|
}
|
|
|
|
|
_ => false,
|
|
|
|
|
}
|
|
|
|
|
}
|
2024-10-11 19:13:31 +02:00
|
|
|
|
|
|
|
|
// Don't add it multiple times:
|
|
|
|
|
let orig_annotatable: Annotatable = match item {
|
|
|
|
|
Annotatable::Item(ref mut iitem) => {
|
2025-03-17 17:06:26 -04:00
|
|
|
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
|
2025-03-07 21:48:54 +01:00
|
|
|
iitem.attrs.push(attr);
|
2024-10-11 19:13:31 +02:00
|
|
|
}
|
2025-03-17 17:06:26 -04:00
|
|
|
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
|
2024-10-11 19:13:31 +02:00
|
|
|
iitem.attrs.push(inline_never.clone());
|
|
|
|
|
}
|
|
|
|
|
Annotatable::Item(iitem.clone())
|
|
|
|
|
}
|
2025-04-21 13:22:56 +05:30
|
|
|
Annotatable::AssocItem(ref mut assoc_item, i @ Impl { .. }) => {
|
2025-03-17 17:06:26 -04:00
|
|
|
if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
|
2025-03-07 21:48:54 +01:00
|
|
|
assoc_item.attrs.push(attr);
|
2024-10-11 19:13:31 +02:00
|
|
|
}
|
2025-03-17 17:06:26 -04:00
|
|
|
if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
|
2024-10-11 19:13:31 +02:00
|
|
|
assoc_item.attrs.push(inline_never.clone());
|
|
|
|
|
}
|
|
|
|
|
Annotatable::AssocItem(assoc_item.clone(), i)
|
|
|
|
|
}
|
2025-03-09 22:55:07 +01:00
|
|
|
Annotatable::Stmt(ref mut stmt) => {
|
|
|
|
|
match stmt.kind {
|
|
|
|
|
ast::StmtKind::Item(ref mut iitem) => {
|
|
|
|
|
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
|
|
|
|
|
iitem.attrs.push(attr);
|
|
|
|
|
}
|
|
|
|
|
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind))
|
|
|
|
|
{
|
|
|
|
|
iitem.attrs.push(inline_never.clone());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
_ => unreachable!("stmt kind checked previously"),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
Annotatable::Stmt(stmt.clone())
|
|
|
|
|
}
|
2024-10-11 19:13:31 +02:00
|
|
|
_ => {
|
|
|
|
|
unreachable!("annotatable kind checked previously")
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
// Now update for d_fn
|
|
|
|
|
rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {
|
|
|
|
|
dspan: DelimSpan::dummy(),
|
|
|
|
|
delim: rustc_ast::token::Delimiter::Parenthesis,
|
|
|
|
|
tokens: ts,
|
|
|
|
|
});
|
2025-03-09 22:55:07 +01:00
|
|
|
|
2025-03-17 17:23:35 -04:00
|
|
|
let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
|
2025-03-09 22:55:07 +01:00
|
|
|
let d_annotatable = match &item {
|
|
|
|
|
Annotatable::AssocItem(_, _) => {
|
|
|
|
|
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
|
|
|
|
|
let d_fn = P(ast::AssocItem {
|
|
|
|
|
attrs: thin_vec![d_attr, inline_never],
|
|
|
|
|
id: ast::DUMMY_NODE_ID,
|
|
|
|
|
span,
|
|
|
|
|
vis,
|
|
|
|
|
kind: assoc_item,
|
|
|
|
|
tokens: None,
|
|
|
|
|
});
|
|
|
|
|
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
|
|
|
|
|
}
|
|
|
|
|
Annotatable::Item(_) => {
|
|
|
|
|
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
|
|
|
|
|
d_fn.vis = vis;
|
|
|
|
|
|
|
|
|
|
Annotatable::Item(d_fn)
|
|
|
|
|
}
|
|
|
|
|
Annotatable::Stmt(_) => {
|
|
|
|
|
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
|
|
|
|
|
d_fn.vis = vis;
|
|
|
|
|
|
|
|
|
|
Annotatable::Stmt(P(ast::Stmt {
|
|
|
|
|
id: ast::DUMMY_NODE_ID,
|
|
|
|
|
kind: ast::StmtKind::Item(d_fn),
|
|
|
|
|
span,
|
|
|
|
|
}))
|
|
|
|
|
}
|
|
|
|
|
_ => {
|
|
|
|
|
unreachable!("item kind checked previously")
|
|
|
|
|
}
|
2024-10-11 19:13:31 +02:00
|
|
|
};
|
|
|
|
|
|
|
|
|
|
return vec![orig_annotatable, d_annotatable];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be
|
|
|
|
|
// mutable references or ptrs, because Enzyme will write into them.
|
|
|
|
|
fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
|
|
|
|
|
let mut ty = ty.clone();
|
|
|
|
|
match ty.kind {
|
|
|
|
|
TyKind::Ptr(ref mut mut_ty) => {
|
|
|
|
|
mut_ty.mutbl = ast::Mutability::Mut;
|
|
|
|
|
}
|
|
|
|
|
TyKind::Ref(_, ref mut mut_ty) => {
|
|
|
|
|
mut_ty.mutbl = ast::Mutability::Mut;
|
|
|
|
|
}
|
|
|
|
|
_ => {
|
|
|
|
|
panic!("unsupported type: {:?}", ty);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
ty
|
|
|
|
|
}
|
|
|
|
|
|
2025-03-17 16:54:41 -04:00
|
|
|
// Will generate a body of the type:
|
|
|
|
|
// ```
|
|
|
|
|
// {
|
|
|
|
|
// unsafe {
|
|
|
|
|
// asm!("NOP");
|
|
|
|
|
// }
|
|
|
|
|
// ::core::hint::black_box(primal(args));
|
|
|
|
|
// ::core::hint::black_box((args, ret));
|
|
|
|
|
// <This part remains to be done by following function>
|
|
|
|
|
// }
|
|
|
|
|
// ```
|
|
|
|
|
fn init_body_helper(
|
2024-10-11 19:13:31 +02:00
|
|
|
ecx: &ExtCtxt<'_>,
|
2025-03-17 16:54:41 -04:00
|
|
|
span: Span,
|
2024-10-11 19:13:31 +02:00
|
|
|
primal: Ident,
|
|
|
|
|
new_names: &[String],
|
|
|
|
|
sig_span: Span,
|
|
|
|
|
new_decl_span: Span,
|
2025-03-17 16:54:41 -04:00
|
|
|
idents: &[Ident],
|
2024-10-11 19:13:31 +02:00
|
|
|
errored: bool,
|
2025-04-20 01:10:50 +02:00
|
|
|
generics: &Generics,
|
2025-03-17 16:54:41 -04:00
|
|
|
) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) {
|
2024-10-11 19:13:31 +02:00
|
|
|
let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
|
|
|
|
|
let noop = ast::InlineAsm {
|
|
|
|
|
asm_macro: ast::AsmMacro::Asm,
|
|
|
|
|
template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())],
|
|
|
|
|
template_strs: Box::new([]),
|
|
|
|
|
operands: vec![],
|
|
|
|
|
clobber_abis: vec![],
|
|
|
|
|
options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM,
|
|
|
|
|
line_spans: vec![],
|
|
|
|
|
};
|
|
|
|
|
let noop_expr = ecx.expr_asm(span, P(noop));
|
|
|
|
|
let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated);
|
|
|
|
|
let unsf_block = ast::Block {
|
|
|
|
|
stmts: thin_vec![ecx.stmt_semi(noop_expr)],
|
|
|
|
|
id: ast::DUMMY_NODE_ID,
|
|
|
|
|
tokens: None,
|
|
|
|
|
rules: unsf,
|
|
|
|
|
span,
|
|
|
|
|
};
|
|
|
|
|
let unsf_expr = ecx.expr_block(P(unsf_block));
|
|
|
|
|
let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
|
2025-04-20 01:10:50 +02:00
|
|
|
let primal_call = gen_primal_call(ecx, span, primal, idents, generics);
|
2024-10-11 19:13:31 +02:00
|
|
|
let black_box_primal_call = ecx.expr_call(
|
|
|
|
|
new_decl_span,
|
|
|
|
|
blackbox_call_expr.clone(),
|
|
|
|
|
thin_vec![primal_call.clone()],
|
|
|
|
|
);
|
|
|
|
|
let tup_args = new_names
|
|
|
|
|
.iter()
|
|
|
|
|
.map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg))))
|
|
|
|
|
.collect();
|
|
|
|
|
|
|
|
|
|
let black_box_remaining_args = ecx.expr_call(
|
|
|
|
|
sig_span,
|
|
|
|
|
blackbox_call_expr.clone(),
|
|
|
|
|
thin_vec![ecx.expr_tuple(sig_span, tup_args)],
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
let mut body = ecx.block(span, ThinVec::new());
|
|
|
|
|
body.stmts.push(ecx.stmt_semi(unsf_expr));
|
|
|
|
|
|
|
|
|
|
// This uses primal args which won't be available if we errored before
|
|
|
|
|
if !errored {
|
|
|
|
|
body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone()));
|
|
|
|
|
}
|
|
|
|
|
body.stmts.push(ecx.stmt_semi(black_box_remaining_args));
|
|
|
|
|
|
2025-03-17 16:54:41 -04:00
|
|
|
(body, primal_call, black_box_primal_call, blackbox_call_expr)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// We only want this function to type-check, since we will replace the body
|
|
|
|
|
/// later on llvm level. Using `loop {}` does not cover all return types anymore,
|
2025-03-18 02:47:37 -04:00
|
|
|
/// so instead we manually build something that should pass the type checker.
|
|
|
|
|
/// We also add a inline_asm line, as one more barrier for rustc to prevent inlining
|
|
|
|
|
/// or const propagation. inline_asm will also triggers an Enzyme crash if due to another
|
|
|
|
|
/// bug would ever try to accidentially differentiate this placeholder function body.
|
2025-03-17 16:54:41 -04:00
|
|
|
/// Finally, we also add back_box usages of all input arguments, to prevent rustc
|
|
|
|
|
/// from optimizing any arguments away.
|
|
|
|
|
fn gen_enzyme_body(
|
|
|
|
|
ecx: &ExtCtxt<'_>,
|
|
|
|
|
x: &AutoDiffAttrs,
|
|
|
|
|
n_active: u32,
|
|
|
|
|
sig: &ast::FnSig,
|
|
|
|
|
d_sig: &ast::FnSig,
|
|
|
|
|
primal: Ident,
|
|
|
|
|
new_names: &[String],
|
|
|
|
|
span: Span,
|
|
|
|
|
sig_span: Span,
|
|
|
|
|
idents: Vec<Ident>,
|
|
|
|
|
errored: bool,
|
2025-04-20 01:10:50 +02:00
|
|
|
generics: &Generics,
|
2025-03-17 16:54:41 -04:00
|
|
|
) -> P<ast::Block> {
|
|
|
|
|
let new_decl_span = d_sig.span;
|
|
|
|
|
|
|
|
|
|
// Just adding some default inline-asm and black_box usages to prevent early inlining
|
|
|
|
|
// and optimizations which alter the function signature.
|
|
|
|
|
//
|
|
|
|
|
// The bb_primal_call is the black_box call of the primal function. We keep it around,
|
|
|
|
|
// since it has the convenient property of returning the type of the primal function,
|
|
|
|
|
// Remember, we only care to match types here.
|
|
|
|
|
// No matter which return we pick, we always wrap it into a std::hint::black_box call,
|
|
|
|
|
// to prevent rustc from propagating it into the caller.
|
|
|
|
|
let (mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper(
|
|
|
|
|
ecx,
|
|
|
|
|
span,
|
|
|
|
|
primal,
|
|
|
|
|
new_names,
|
|
|
|
|
sig_span,
|
|
|
|
|
new_decl_span,
|
|
|
|
|
&idents,
|
|
|
|
|
errored,
|
2025-04-20 01:10:50 +02:00
|
|
|
generics,
|
2025-03-17 16:54:41 -04:00
|
|
|
);
|
|
|
|
|
|
2024-10-11 19:13:31 +02:00
|
|
|
if !has_ret(&d_sig.decl.output) {
|
|
|
|
|
// there is no return type that we have to match, () works fine.
|
|
|
|
|
return body;
|
|
|
|
|
}
|
|
|
|
|
|
2025-04-03 17:19:11 -04:00
|
|
|
// Everything from here onwards just tries to fullfil the return type. Fun!
|
|
|
|
|
|
2024-10-11 19:13:31 +02:00
|
|
|
// having an active-only return means we'll drop the original return type.
|
|
|
|
|
// So that can be treated identical to not having one in the first place.
|
|
|
|
|
let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();
|
|
|
|
|
|
|
|
|
|
if primal_ret && n_active == 0 && x.mode.is_rev() {
|
|
|
|
|
// We only have the primal ret.
|
2025-03-17 16:54:41 -04:00
|
|
|
body.stmts.push(ecx.stmt_expr(bb_primal_call));
|
2024-10-11 19:13:31 +02:00
|
|
|
return body;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if !primal_ret && n_active == 1 {
|
|
|
|
|
// Again no tuple return, so return default float val.
|
|
|
|
|
let ty = match d_sig.decl.output {
|
|
|
|
|
FnRetTy::Ty(ref ty) => ty.clone(),
|
|
|
|
|
FnRetTy::Default(span) => {
|
|
|
|
|
panic!("Did not expect Default ret ty: {:?}", span);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
let arg = ty.kind.is_simple_path().unwrap();
|
2025-04-19 18:11:57 +02:00
|
|
|
let tmp = ecx.def_site_path(&[arg, kw::Default]);
|
2024-10-11 19:13:31 +02:00
|
|
|
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
|
|
|
|
|
let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
|
|
|
|
|
body.stmts.push(ecx.stmt_expr(default_call_expr));
|
|
|
|
|
return body;
|
|
|
|
|
}
|
|
|
|
|
|
2025-04-19 18:11:57 +02:00
|
|
|
let mut exprs: P<ast::Expr> = primal_call;
|
2024-10-11 19:13:31 +02:00
|
|
|
let d_ret_ty = match d_sig.decl.output {
|
|
|
|
|
FnRetTy::Ty(ref ty) => ty.clone(),
|
|
|
|
|
FnRetTy::Default(span) => {
|
|
|
|
|
panic!("Did not expect Default ret ty: {:?}", span);
|
|
|
|
|
}
|
|
|
|
|
};
|
2025-04-03 17:19:11 -04:00
|
|
|
if x.mode.is_fwd() {
|
|
|
|
|
// Fwd mode is easy. If the return activity is Const, we support arbitrary types.
|
|
|
|
|
// Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
|
|
|
|
|
// We checked that (on a best-effort base) in the preceding gen_enzyme_decl function.
|
|
|
|
|
// In all three cases, we can return `std::hint::black_box(<T>::default())`.
|
|
|
|
|
if x.ret_activity == DiffActivity::Const {
|
|
|
|
|
// Here we call the primal function, since our dummy function has the same return
|
|
|
|
|
// type due to the Const return activity.
|
|
|
|
|
exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
|
|
|
|
|
} else {
|
2025-04-19 18:11:57 +02:00
|
|
|
let q = QSelf { ty: d_ret_ty, path_span: span, position: 0 };
|
2025-04-03 17:19:11 -04:00
|
|
|
let y =
|
|
|
|
|
ExprKind::Path(Some(P(q)), ecx.path_ident(span, Ident::from_str("default")));
|
|
|
|
|
let default_call_expr = ecx.expr(span, y);
|
2024-10-11 19:13:31 +02:00
|
|
|
let default_call_expr =
|
|
|
|
|
ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
|
2025-04-03 17:19:11 -04:00
|
|
|
exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![default_call_expr]);
|
2024-10-11 19:13:31 +02:00
|
|
|
}
|
2025-04-03 17:19:11 -04:00
|
|
|
} else if x.mode.is_rev() {
|
|
|
|
|
if x.width == 1 {
|
|
|
|
|
// We either have `-> ArbitraryType` or `-> (ArbitraryType, repeated_float_scalars)`.
|
|
|
|
|
match d_ret_ty.kind {
|
|
|
|
|
TyKind::Tup(ref args) => {
|
|
|
|
|
// We have a tuple return type. We need to create a tuple of the same size
|
|
|
|
|
// and fill it with default values.
|
|
|
|
|
let mut exprs2 = thin_vec![exprs];
|
|
|
|
|
for arg in args.iter().skip(1) {
|
|
|
|
|
let arg = arg.kind.is_simple_path().unwrap();
|
2025-04-19 18:11:57 +02:00
|
|
|
let tmp = ecx.def_site_path(&[arg, kw::Default]);
|
2025-04-03 17:19:11 -04:00
|
|
|
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
|
|
|
|
|
let default_call_expr =
|
|
|
|
|
ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
|
|
|
|
|
exprs2.push(default_call_expr);
|
|
|
|
|
}
|
|
|
|
|
exprs = ecx.expr_tuple(new_decl_span, exprs2);
|
|
|
|
|
}
|
|
|
|
|
_ => {
|
|
|
|
|
// Interestingly, even the `-> ArbitraryType` case
|
|
|
|
|
// ends up getting matched and handled correctly above,
|
|
|
|
|
// so we don't have to handle any other case for now.
|
|
|
|
|
panic!("Unsupported return type: {:?}", d_ret_ty);
|
|
|
|
|
}
|
|
|
|
|
}
|
2024-10-11 19:13:31 +02:00
|
|
|
}
|
2025-04-03 17:19:11 -04:00
|
|
|
exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
|
|
|
|
|
} else {
|
|
|
|
|
unreachable!("Unsupported mode: {:?}", x.mode);
|
2024-10-11 19:13:31 +02:00
|
|
|
}
|
2025-04-03 17:19:11 -04:00
|
|
|
|
|
|
|
|
body.stmts.push(ecx.stmt_expr(exprs));
|
2024-10-11 19:13:31 +02:00
|
|
|
|
|
|
|
|
body
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn gen_primal_call(
|
|
|
|
|
ecx: &ExtCtxt<'_>,
|
|
|
|
|
span: Span,
|
|
|
|
|
primal: Ident,
|
2025-03-17 16:54:41 -04:00
|
|
|
idents: &[Ident],
|
2025-04-20 01:10:50 +02:00
|
|
|
generics: &Generics,
|
2024-10-11 19:13:31 +02:00
|
|
|
) -> P<ast::Expr> {
|
|
|
|
|
let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
|
2025-04-20 01:10:50 +02:00
|
|
|
|
2024-10-11 19:13:31 +02:00
|
|
|
if has_self {
|
|
|
|
|
let args: ThinVec<_> =
|
|
|
|
|
idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
|
|
|
|
|
let self_expr = ecx.expr_self(span);
|
2025-03-07 21:48:54 +01:00
|
|
|
ecx.expr_method_call(span, self_expr, primal, args)
|
2024-10-11 19:13:31 +02:00
|
|
|
} else {
|
|
|
|
|
let args: ThinVec<_> =
|
|
|
|
|
idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
|
2025-04-20 01:10:50 +02:00
|
|
|
let mut primal_path = ecx.path_ident(span, primal);
|
|
|
|
|
|
|
|
|
|
let is_generic = !generics.params.is_empty();
|
|
|
|
|
|
|
|
|
|
match (is_generic, primal_path.segments.last_mut()) {
|
|
|
|
|
(true, Some(function_path)) => {
|
|
|
|
|
let primal_generic_types = generics
|
|
|
|
|
.params
|
|
|
|
|
.iter()
|
|
|
|
|
.filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. }));
|
|
|
|
|
|
|
|
|
|
let generated_generic_types = primal_generic_types
|
|
|
|
|
.map(|type_param| {
|
|
|
|
|
let generic_param = TyKind::Path(
|
|
|
|
|
None,
|
|
|
|
|
ast::Path {
|
|
|
|
|
span,
|
|
|
|
|
segments: thin_vec![ast::PathSegment {
|
|
|
|
|
ident: type_param.ident,
|
|
|
|
|
args: None,
|
|
|
|
|
id: ast::DUMMY_NODE_ID,
|
|
|
|
|
}],
|
|
|
|
|
tokens: None,
|
|
|
|
|
},
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
ast::AngleBracketedArg::Arg(ast::GenericArg::Type(P(ast::Ty {
|
|
|
|
|
id: type_param.id,
|
|
|
|
|
span,
|
|
|
|
|
kind: generic_param,
|
|
|
|
|
tokens: None,
|
|
|
|
|
})))
|
|
|
|
|
})
|
|
|
|
|
.collect();
|
|
|
|
|
|
|
|
|
|
function_path.args =
|
|
|
|
|
Some(P(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs {
|
|
|
|
|
span,
|
|
|
|
|
args: generated_generic_types,
|
|
|
|
|
})));
|
|
|
|
|
}
|
|
|
|
|
_ => {}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let primal_call_expr = ecx.expr_path(primal_path);
|
2024-10-11 19:13:31 +02:00
|
|
|
ecx.expr_call(span, primal_call_expr, args)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
|
|
|
|
|
// be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
|
|
|
|
|
// Active arguments must be scalars. Their shadow argument is added to the return type (and will be
|
|
|
|
|
// zero-initialized by Enzyme).
|
|
|
|
|
// Each argument of the primal function (and the return type if existing) must be annotated with an
|
|
|
|
|
// activity.
|
|
|
|
|
//
|
|
|
|
|
// Error handling: If the user provides an invalid configuration (incorrect numbers, types, or
|
|
|
|
|
// both), we emit an error and return the original signature. This allows us to continue parsing.
|
2025-03-10 16:05:27 +01:00
|
|
|
// FIXME(Sa4dUs): make individual activities' span available so errors
|
|
|
|
|
// can point to only the activity instead of the entire attribute
|
2024-10-11 19:13:31 +02:00
|
|
|
fn gen_enzyme_decl(
|
|
|
|
|
ecx: &ExtCtxt<'_>,
|
|
|
|
|
sig: &ast::FnSig,
|
|
|
|
|
x: &AutoDiffAttrs,
|
|
|
|
|
span: Span,
|
|
|
|
|
) -> (ast::FnSig, Vec<String>, Vec<Ident>, bool) {
|
|
|
|
|
let dcx = ecx.sess.dcx();
|
|
|
|
|
let has_ret = has_ret(&sig.decl.output);
|
|
|
|
|
let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };
|
|
|
|
|
let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 };
|
|
|
|
|
if sig_args != num_activities {
|
|
|
|
|
dcx.emit_err(errors::AutoDiffInvalidNumberActivities {
|
|
|
|
|
span,
|
|
|
|
|
expected: sig_args,
|
|
|
|
|
found: num_activities,
|
|
|
|
|
});
|
|
|
|
|
// This is not the right signature, but we can continue parsing.
|
|
|
|
|
return (sig.clone(), vec![], vec![], true);
|
|
|
|
|
}
|
|
|
|
|
assert!(sig.decl.inputs.len() == x.input_activity.len());
|
|
|
|
|
assert!(has_ret == x.has_ret_activity());
|
|
|
|
|
let mut d_decl = sig.decl.clone();
|
|
|
|
|
let mut d_inputs = Vec::new();
|
|
|
|
|
let mut new_inputs = Vec::new();
|
|
|
|
|
let mut idents = Vec::new();
|
|
|
|
|
let mut act_ret = ThinVec::new();
|
|
|
|
|
|
|
|
|
|
// We have two loops, a first one just to check the activities and types and possibly report
|
|
|
|
|
// multiple errors in one compilation session.
|
|
|
|
|
let mut errors = false;
|
|
|
|
|
for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
|
|
|
|
|
if !valid_input_activity(x.mode, *activity) {
|
|
|
|
|
dcx.emit_err(errors::AutoDiffInvalidApplicationModeAct {
|
|
|
|
|
span,
|
|
|
|
|
mode: x.mode.to_string(),
|
|
|
|
|
act: activity.to_string(),
|
|
|
|
|
});
|
|
|
|
|
errors = true;
|
|
|
|
|
}
|
|
|
|
|
if !valid_ty_for_activity(&arg.ty, *activity) {
|
|
|
|
|
dcx.emit_err(errors::AutoDiffInvalidTypeForActivity {
|
|
|
|
|
span: arg.ty.span,
|
|
|
|
|
act: activity.to_string(),
|
|
|
|
|
});
|
|
|
|
|
errors = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
2025-03-07 17:37:50 +01:00
|
|
|
|
|
|
|
|
if has_ret && !valid_ret_activity(x.mode, x.ret_activity) {
|
|
|
|
|
dcx.emit_err(errors::AutoDiffInvalidRetAct {
|
|
|
|
|
span,
|
|
|
|
|
mode: x.mode.to_string(),
|
|
|
|
|
act: x.ret_activity.to_string(),
|
|
|
|
|
});
|
|
|
|
|
// We don't set `errors = true` to avoid annoying type errors relative
|
|
|
|
|
// to the expanded macro type signature
|
|
|
|
|
}
|
|
|
|
|
|
2024-10-11 19:13:31 +02:00
|
|
|
if errors {
|
|
|
|
|
// This is not the right signature, but we can continue parsing.
|
|
|
|
|
return (sig.clone(), new_inputs, idents, true);
|
|
|
|
|
}
|
2025-03-07 17:37:50 +01:00
|
|
|
|
2024-10-11 19:13:31 +02:00
|
|
|
let unsafe_activities = x
|
|
|
|
|
.input_activity
|
|
|
|
|
.iter()
|
|
|
|
|
.any(|&act| matches!(act, DiffActivity::DuplicatedOnly | DiffActivity::DualOnly));
|
|
|
|
|
for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
|
|
|
|
|
d_inputs.push(arg.clone());
|
|
|
|
|
match activity {
|
|
|
|
|
DiffActivity::Active => {
|
|
|
|
|
act_ret.push(arg.ty.clone());
|
2025-04-03 17:19:11 -04:00
|
|
|
// if width =/= 1, then push [arg.ty; width] to act_ret
|
2024-10-11 19:13:31 +02:00
|
|
|
}
|
|
|
|
|
DiffActivity::ActiveOnly => {
|
|
|
|
|
// We will add the active scalar to the return type.
|
|
|
|
|
// This is handled later.
|
|
|
|
|
}
|
|
|
|
|
DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
|
2025-04-03 17:19:11 -04:00
|
|
|
for i in 0..x.width {
|
|
|
|
|
let mut shadow_arg = arg.clone();
|
|
|
|
|
// We += into the shadow in reverse mode.
|
|
|
|
|
shadow_arg.ty = P(assure_mut_ref(&arg.ty));
|
|
|
|
|
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
|
|
|
|
ident.name
|
|
|
|
|
} else {
|
|
|
|
|
debug!("{:#?}", &shadow_arg.pat);
|
|
|
|
|
panic!("not an ident?");
|
|
|
|
|
};
|
|
|
|
|
let name: String = format!("d{}_{}", old_name, i);
|
|
|
|
|
new_inputs.push(name.clone());
|
|
|
|
|
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
|
|
|
|
|
shadow_arg.pat = P(ast::Pat {
|
|
|
|
|
id: ast::DUMMY_NODE_ID,
|
|
|
|
|
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
|
|
|
|
span: shadow_arg.pat.span,
|
|
|
|
|
tokens: shadow_arg.pat.tokens.clone(),
|
|
|
|
|
});
|
|
|
|
|
d_inputs.push(shadow_arg.clone());
|
|
|
|
|
}
|
2024-10-11 19:13:31 +02:00
|
|
|
}
|
2025-04-05 03:10:19 -04:00
|
|
|
DiffActivity::Dual
|
|
|
|
|
| DiffActivity::DualOnly
|
|
|
|
|
| DiffActivity::Dualv
|
|
|
|
|
| DiffActivity::DualvOnly => {
|
|
|
|
|
// the *v variants get lowered to enzyme_dupv and enzyme_dupnoneedv, which cause
|
|
|
|
|
// Enzyme to not expect N arguments, but one argument (which is instead larger).
|
|
|
|
|
let iterations =
|
|
|
|
|
if matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) {
|
|
|
|
|
1
|
|
|
|
|
} else {
|
|
|
|
|
x.width
|
|
|
|
|
};
|
|
|
|
|
for i in 0..iterations {
|
2025-04-03 17:19:11 -04:00
|
|
|
let mut shadow_arg = arg.clone();
|
|
|
|
|
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
|
|
|
|
ident.name
|
|
|
|
|
} else {
|
|
|
|
|
debug!("{:#?}", &shadow_arg.pat);
|
|
|
|
|
panic!("not an ident?");
|
|
|
|
|
};
|
|
|
|
|
let name: String = format!("b{}_{}", old_name, i);
|
|
|
|
|
new_inputs.push(name.clone());
|
|
|
|
|
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
|
|
|
|
|
shadow_arg.pat = P(ast::Pat {
|
|
|
|
|
id: ast::DUMMY_NODE_ID,
|
|
|
|
|
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
|
|
|
|
span: shadow_arg.pat.span,
|
|
|
|
|
tokens: shadow_arg.pat.tokens.clone(),
|
|
|
|
|
});
|
|
|
|
|
d_inputs.push(shadow_arg.clone());
|
|
|
|
|
}
|
2024-10-11 19:13:31 +02:00
|
|
|
}
|
|
|
|
|
DiffActivity::Const => {
|
|
|
|
|
// Nothing to do here.
|
|
|
|
|
}
|
2025-04-05 03:10:19 -04:00
|
|
|
DiffActivity::None | DiffActivity::FakeActivitySize(_) => {
|
2024-10-11 19:13:31 +02:00
|
|
|
panic!("Should not happen");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
|
|
|
|
idents.push(ident.clone());
|
|
|
|
|
} else {
|
|
|
|
|
panic!("not an ident?");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly;
|
|
|
|
|
if active_only_ret {
|
|
|
|
|
assert!(x.mode.is_rev());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// If we return a scalar in the primal and the scalar is active,
|
|
|
|
|
// then add it as last arg to the inputs.
|
|
|
|
|
if x.mode.is_rev() {
|
|
|
|
|
match x.ret_activity {
|
|
|
|
|
DiffActivity::Active | DiffActivity::ActiveOnly => {
|
|
|
|
|
let ty = match d_decl.output {
|
|
|
|
|
FnRetTy::Ty(ref ty) => ty.clone(),
|
|
|
|
|
FnRetTy::Default(span) => {
|
|
|
|
|
panic!("Did not expect Default ret ty: {:?}", span);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
let name = "dret".to_string();
|
|
|
|
|
let ident = Ident::from_str_and_span(&name, ty.span);
|
|
|
|
|
let shadow_arg = ast::Param {
|
|
|
|
|
attrs: ThinVec::new(),
|
|
|
|
|
ty: ty.clone(),
|
|
|
|
|
pat: P(ast::Pat {
|
|
|
|
|
id: ast::DUMMY_NODE_ID,
|
|
|
|
|
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
|
|
|
|
span: ty.span,
|
|
|
|
|
tokens: None,
|
|
|
|
|
}),
|
|
|
|
|
id: ast::DUMMY_NODE_ID,
|
|
|
|
|
span: ty.span,
|
|
|
|
|
is_placeholder: false,
|
|
|
|
|
};
|
|
|
|
|
d_inputs.push(shadow_arg);
|
|
|
|
|
new_inputs.push(name);
|
|
|
|
|
}
|
|
|
|
|
_ => {}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
d_decl.inputs = d_inputs.into();
|
|
|
|
|
|
|
|
|
|
if x.mode.is_fwd() {
|
2025-04-03 17:19:11 -04:00
|
|
|
let ty = match d_decl.output {
|
|
|
|
|
FnRetTy::Ty(ref ty) => ty.clone(),
|
|
|
|
|
FnRetTy::Default(span) => {
|
|
|
|
|
// We want to return std::hint::black_box(()).
|
|
|
|
|
let kind = TyKind::Tup(ThinVec::new());
|
|
|
|
|
let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
|
|
|
|
|
d_decl.output = FnRetTy::Ty(ty.clone());
|
|
|
|
|
assert!(matches!(x.ret_activity, DiffActivity::None));
|
|
|
|
|
// this won't be used below, so any type would be fine.
|
|
|
|
|
ty
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
2025-04-05 03:10:19 -04:00
|
|
|
if matches!(x.ret_activity, DiffActivity::Dual | DiffActivity::Dualv) {
|
|
|
|
|
let kind = if x.width == 1 || matches!(x.ret_activity, DiffActivity::Dualv) {
|
2025-04-03 17:19:11 -04:00
|
|
|
// Dual can only be used for f32/f64 ret.
|
|
|
|
|
// In that case we return now a tuple with two floats.
|
|
|
|
|
TyKind::Tup(thin_vec![ty.clone(), ty.clone()])
|
|
|
|
|
} else {
|
|
|
|
|
// We have to return [T; width+1], +1 for the primal return.
|
|
|
|
|
let anon_const = rustc_ast::AnonConst {
|
|
|
|
|
id: ast::DUMMY_NODE_ID,
|
|
|
|
|
value: ecx.expr_usize(span, 1 + x.width as usize),
|
|
|
|
|
};
|
|
|
|
|
TyKind::Array(ty.clone(), anon_const)
|
2024-10-11 19:13:31 +02:00
|
|
|
};
|
|
|
|
|
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
|
|
|
|
|
d_decl.output = FnRetTy::Ty(ty);
|
|
|
|
|
}
|
2025-04-05 03:10:19 -04:00
|
|
|
if matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) {
|
2024-10-11 19:13:31 +02:00
|
|
|
// No need to change the return type,
|
2025-04-03 17:19:11 -04:00
|
|
|
// we will just return the shadow in place of the primal return.
|
|
|
|
|
// However, if we have a width > 1, then we don't return -> T, but -> [T; width]
|
|
|
|
|
if x.width > 1 {
|
|
|
|
|
let anon_const = rustc_ast::AnonConst {
|
|
|
|
|
id: ast::DUMMY_NODE_ID,
|
|
|
|
|
value: ecx.expr_usize(span, x.width as usize),
|
|
|
|
|
};
|
|
|
|
|
let kind = TyKind::Array(ty.clone(), anon_const);
|
|
|
|
|
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
|
|
|
|
|
d_decl.output = FnRetTy::Ty(ty);
|
|
|
|
|
}
|
2024-10-11 19:13:31 +02:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// If we use ActiveOnly, drop the original return value.
|
|
|
|
|
d_decl.output =
|
|
|
|
|
if active_only_ret { FnRetTy::Default(span) } else { d_decl.output.clone() };
|
|
|
|
|
|
|
|
|
|
trace!("act_ret: {:?}", act_ret);
|
|
|
|
|
|
|
|
|
|
// If we have an active input scalar, add it's gradient to the
|
|
|
|
|
// return type. This might require changing the return type to a
|
|
|
|
|
// tuple.
|
|
|
|
|
if act_ret.len() > 0 {
|
|
|
|
|
let ret_ty = match d_decl.output {
|
|
|
|
|
FnRetTy::Ty(ref ty) => {
|
|
|
|
|
if !active_only_ret {
|
|
|
|
|
act_ret.insert(0, ty.clone());
|
|
|
|
|
}
|
|
|
|
|
let kind = TyKind::Tup(act_ret);
|
|
|
|
|
P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None })
|
|
|
|
|
}
|
|
|
|
|
FnRetTy::Default(span) => {
|
|
|
|
|
if act_ret.len() == 1 {
|
|
|
|
|
act_ret[0].clone()
|
|
|
|
|
} else {
|
|
|
|
|
let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect());
|
|
|
|
|
P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None })
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
d_decl.output = FnRetTy::Ty(ret_ty);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let mut d_header = sig.header.clone();
|
|
|
|
|
if unsafe_activities {
|
|
|
|
|
d_header.safety = rustc_ast::Safety::Unsafe(span);
|
|
|
|
|
}
|
|
|
|
|
let d_sig = FnSig { header: d_header, decl: d_decl, span };
|
|
|
|
|
trace!("Generated signature: {:?}", d_sig);
|
|
|
|
|
(d_sig, new_inputs, idents, false)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2025-05-06 09:19:33 +02:00
|
|
|
pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};
|