Split autodiff into autodiff_forward and autodiff_reverse

Pending fix.
```
error: cannot find a built-in macro with name `autodiff_forward`
    --> library\core\src\macros\mod.rs:1542:5
     |
1542 | /     pub macro autodiff_forward($item:item) {
1543 | |         /* compiler built-in */
1544 | |     }
     | |_____^

error: cannot find a built-in macro with name `autodiff_reverse`
    --> library\core\src\macros\mod.rs:1549:5
     |
1549 | /     pub macro autodiff_reverse($item:item) {
1550 | |         /* compiler built-in */
1551 | |     }
     | |_____^

error: could not compile `core` (lib) due to 2 previous errors
```
This commit is contained in:
Marcelo Domínguez
2025-05-06 09:19:33 +02:00
committed by Marcelo Domínguez
parent f8e9e7636a
commit b21c9e7bfb
6 changed files with 48 additions and 17 deletions

View File

@@ -88,25 +88,20 @@ mod llvm_enzyme {
has_ret: bool, has_ret: bool,
) -> AutoDiffAttrs { ) -> AutoDiffAttrs {
let dcx = ecx.sess.dcx(); let dcx = ecx.sess.dcx();
let mode = name(&meta_item[1]);
let Ok(mode) = DiffMode::from_str(&mode) else {
dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
return AutoDiffAttrs::error();
};
// Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode. // Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
// If he doesn't specify an integer (=width), we default to scalar mode, thus width=1. // If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
let mut first_activity = 2; let mut first_activity = 1;
let width = if let [_, _, x, ..] = &meta_item[..] let width = if let [_, x, ..] = &meta_item[..]
&& let Some(x) = width(x) && let Some(x) = width(x)
{ {
first_activity = 3; first_activity = 2;
match x.try_into() { match x.try_into() {
Ok(x) => x, Ok(x) => x,
Err(_) => { Err(_) => {
dcx.emit_err(errors::AutoDiffInvalidWidth { dcx.emit_err(errors::AutoDiffInvalidWidth {
span: meta_item[2].span(), span: meta_item[1].span(),
width: x, width: x,
}); });
return AutoDiffAttrs::error(); return AutoDiffAttrs::error();
@@ -150,7 +145,7 @@ mod llvm_enzyme {
}; };
AutoDiffAttrs { AutoDiffAttrs {
mode, mode: DiffMode::Error,
width, width,
ret_activity: *ret_activity, ret_activity: *ret_activity,
input_activity: input_activity.to_vec(), input_activity: input_activity.to_vec(),
@@ -165,6 +160,24 @@ mod llvm_enzyme {
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone)); ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
} }
pub(crate) fn expand_forward(
ecx: &mut ExtCtxt<'_>,
expand_span: Span,
meta_item: &ast::MetaItem,
item: Annotatable,
) -> Vec<Annotatable> {
expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward)
}
pub(crate) fn expand_reverse(
ecx: &mut ExtCtxt<'_>,
expand_span: Span,
meta_item: &ast::MetaItem,
item: Annotatable,
) -> Vec<Annotatable> {
expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse)
}
/// We expand the autodiff macro to generate a new placeholder function which passes /// We expand the autodiff macro to generate a new placeholder function which passes
/// type-checking and can be called by users. The function body of the placeholder function will /// type-checking and can be called by users. The function body of the placeholder function will
/// later be replaced on LLVM-IR level, so the design of the body is less important and for now /// later be replaced on LLVM-IR level, so the design of the body is less important and for now
@@ -198,11 +211,12 @@ mod llvm_enzyme {
/// ``` /// ```
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
/// in CI. /// in CI.
pub(crate) fn expand( pub(crate) fn expand_with_mode(
ecx: &mut ExtCtxt<'_>, ecx: &mut ExtCtxt<'_>,
expand_span: Span, expand_span: Span,
meta_item: &ast::MetaItem, meta_item: &ast::MetaItem,
mut item: Annotatable, mut item: Annotatable,
mode: DiffMode,
) -> Vec<Annotatable> { ) -> Vec<Annotatable> {
if cfg!(not(llvm_enzyme)) { if cfg!(not(llvm_enzyme)) {
ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span }); ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
@@ -289,7 +303,8 @@ mod llvm_enzyme {
ts.pop(); ts.pop();
let ts: TokenStream = TokenStream::from_iter(ts); let ts: TokenStream = TokenStream::from_iter(ts);
let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret); let mut x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
x.mode = mode;
if !x.is_active() { if !x.is_active() {
// We encountered an error, so we return the original item. // We encountered an error, so we return the original item.
// This allows us to potentially parse other attributes. // This allows us to potentially parse other attributes.
@@ -1017,4 +1032,4 @@ mod llvm_enzyme {
} }
} }
pub(crate) use llvm_enzyme::expand; pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};

View File

@@ -112,7 +112,8 @@ pub fn register_builtin_macros(resolver: &mut dyn ResolverExpand) {
register_attr! { register_attr! {
alloc_error_handler: alloc_error_handler::expand, alloc_error_handler: alloc_error_handler::expand,
autodiff: autodiff::expand, autodiff_forward: autodiff::expand_forward,
autodiff_reverse: autodiff::expand_reverse,
bench: test::expand_bench, bench: test::expand_bench,
cfg_accessible: cfg_accessible::Expander, cfg_accessible: cfg_accessible::Expander,
cfg_eval: cfg_eval::expand, cfg_eval: cfg_eval::expand,

View File

@@ -255,7 +255,7 @@ impl<'tcx> CheckAttrVisitor<'tcx> {
self.check_generic_attr(hir_id, attr, target, Target::Fn); self.check_generic_attr(hir_id, attr, target, Target::Fn);
self.check_proc_macro(hir_id, target, ProcMacroKind::Derive) self.check_proc_macro(hir_id, target, ProcMacroKind::Derive)
} }
[sym::autodiff, ..] => { [sym::autodiff_forward, ..] | [sym::autodiff_reverse, ..] => {
self.check_autodiff(hir_id, attr, span, target) self.check_autodiff(hir_id, attr, span, target)
} }
[sym::coroutine, ..] => { [sym::coroutine, ..] => {

View File

@@ -531,7 +531,8 @@ symbols! {
audit_that, audit_that,
augmented_assignments, augmented_assignments,
auto_traits, auto_traits,
autodiff, autodiff_forward,
autodiff_reverse,
automatically_derived, automatically_derived,
avx, avx,
avx10_target_feature, avx10_target_feature,

View File

@@ -229,7 +229,7 @@ pub mod assert_matches {
/// Unstable module containing the unstable `autodiff` macro. /// Unstable module containing the unstable `autodiff` macro.
pub mod autodiff { pub mod autodiff {
#[unstable(feature = "autodiff", issue = "124509")] #[unstable(feature = "autodiff", issue = "124509")]
pub use crate::macros::builtin::autodiff; pub use crate::macros::builtin::{autodiff_forward, autodiff_reverse};
} }
#[unstable(feature = "contracts", issue = "128044")] #[unstable(feature = "contracts", issue = "128044")]

View File

@@ -1536,6 +1536,20 @@ pub(crate) mod builtin {
/* compiler built-in */ /* compiler built-in */
} }
#[unstable(feature = "autodiff", issue = "124509")]
#[allow_internal_unstable(rustc_attrs)]
#[rustc_builtin_macro]
pub macro autodiff_forward($item:item) {
/* compiler built-in */
}
#[unstable(feature = "autodiff", issue = "124509")]
#[allow_internal_unstable(rustc_attrs)]
#[rustc_builtin_macro]
pub macro autodiff_reverse($item:item) {
/* compiler built-in */
}
/// Asserts that a boolean expression is `true` at runtime. /// Asserts that a boolean expression is `true` at runtime.
/// ///
/// This will invoke the [`panic!`] macro if the provided expression cannot be /// This will invoke the [`panic!`] macro if the provided expression cannot be