Split autodiff into autodiff_forward and autodiff_reverse
Pending fix.
```
error: cannot find a built-in macro with name `autodiff_forward`
--> library\core\src\macros\mod.rs:1542:5
|
1542 | / pub macro autodiff_forward($item:item) {
1543 | | /* compiler built-in */
1544 | | }
| |_____^
error: cannot find a built-in macro with name `autodiff_reverse`
--> library\core\src\macros\mod.rs:1549:5
|
1549 | / pub macro autodiff_reverse($item:item) {
1550 | | /* compiler built-in */
1551 | | }
| |_____^
error: could not compile `core` (lib) due to 2 previous errors
```
This commit is contained in:
committed by
Marcelo Domínguez
parent
f8e9e7636a
commit
b21c9e7bfb
@@ -88,25 +88,20 @@ mod llvm_enzyme {
|
|||||||
has_ret: bool,
|
has_ret: bool,
|
||||||
) -> AutoDiffAttrs {
|
) -> AutoDiffAttrs {
|
||||||
let dcx = ecx.sess.dcx();
|
let dcx = ecx.sess.dcx();
|
||||||
let mode = name(&meta_item[1]);
|
|
||||||
let Ok(mode) = DiffMode::from_str(&mode) else {
|
|
||||||
dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
|
|
||||||
return AutoDiffAttrs::error();
|
|
||||||
};
|
|
||||||
|
|
||||||
// Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
|
// Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
|
||||||
// If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
|
// If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
|
||||||
let mut first_activity = 2;
|
let mut first_activity = 1;
|
||||||
|
|
||||||
let width = if let [_, _, x, ..] = &meta_item[..]
|
let width = if let [_, x, ..] = &meta_item[..]
|
||||||
&& let Some(x) = width(x)
|
&& let Some(x) = width(x)
|
||||||
{
|
{
|
||||||
first_activity = 3;
|
first_activity = 2;
|
||||||
match x.try_into() {
|
match x.try_into() {
|
||||||
Ok(x) => x,
|
Ok(x) => x,
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
dcx.emit_err(errors::AutoDiffInvalidWidth {
|
dcx.emit_err(errors::AutoDiffInvalidWidth {
|
||||||
span: meta_item[2].span(),
|
span: meta_item[1].span(),
|
||||||
width: x,
|
width: x,
|
||||||
});
|
});
|
||||||
return AutoDiffAttrs::error();
|
return AutoDiffAttrs::error();
|
||||||
@@ -150,7 +145,7 @@ mod llvm_enzyme {
|
|||||||
};
|
};
|
||||||
|
|
||||||
AutoDiffAttrs {
|
AutoDiffAttrs {
|
||||||
mode,
|
mode: DiffMode::Error,
|
||||||
width,
|
width,
|
||||||
ret_activity: *ret_activity,
|
ret_activity: *ret_activity,
|
||||||
input_activity: input_activity.to_vec(),
|
input_activity: input_activity.to_vec(),
|
||||||
@@ -165,6 +160,24 @@ mod llvm_enzyme {
|
|||||||
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
|
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn expand_forward(
|
||||||
|
ecx: &mut ExtCtxt<'_>,
|
||||||
|
expand_span: Span,
|
||||||
|
meta_item: &ast::MetaItem,
|
||||||
|
item: Annotatable,
|
||||||
|
) -> Vec<Annotatable> {
|
||||||
|
expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn expand_reverse(
|
||||||
|
ecx: &mut ExtCtxt<'_>,
|
||||||
|
expand_span: Span,
|
||||||
|
meta_item: &ast::MetaItem,
|
||||||
|
item: Annotatable,
|
||||||
|
) -> Vec<Annotatable> {
|
||||||
|
expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse)
|
||||||
|
}
|
||||||
|
|
||||||
/// We expand the autodiff macro to generate a new placeholder function which passes
|
/// We expand the autodiff macro to generate a new placeholder function which passes
|
||||||
/// type-checking and can be called by users. The function body of the placeholder function will
|
/// type-checking and can be called by users. The function body of the placeholder function will
|
||||||
/// later be replaced on LLVM-IR level, so the design of the body is less important and for now
|
/// later be replaced on LLVM-IR level, so the design of the body is less important and for now
|
||||||
@@ -198,11 +211,12 @@ mod llvm_enzyme {
|
|||||||
/// ```
|
/// ```
|
||||||
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
|
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
|
||||||
/// in CI.
|
/// in CI.
|
||||||
pub(crate) fn expand(
|
pub(crate) fn expand_with_mode(
|
||||||
ecx: &mut ExtCtxt<'_>,
|
ecx: &mut ExtCtxt<'_>,
|
||||||
expand_span: Span,
|
expand_span: Span,
|
||||||
meta_item: &ast::MetaItem,
|
meta_item: &ast::MetaItem,
|
||||||
mut item: Annotatable,
|
mut item: Annotatable,
|
||||||
|
mode: DiffMode,
|
||||||
) -> Vec<Annotatable> {
|
) -> Vec<Annotatable> {
|
||||||
if cfg!(not(llvm_enzyme)) {
|
if cfg!(not(llvm_enzyme)) {
|
||||||
ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
|
ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
|
||||||
@@ -289,7 +303,8 @@ mod llvm_enzyme {
|
|||||||
ts.pop();
|
ts.pop();
|
||||||
let ts: TokenStream = TokenStream::from_iter(ts);
|
let ts: TokenStream = TokenStream::from_iter(ts);
|
||||||
|
|
||||||
let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
|
let mut x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
|
||||||
|
x.mode = mode;
|
||||||
if !x.is_active() {
|
if !x.is_active() {
|
||||||
// We encountered an error, so we return the original item.
|
// We encountered an error, so we return the original item.
|
||||||
// This allows us to potentially parse other attributes.
|
// This allows us to potentially parse other attributes.
|
||||||
@@ -1017,4 +1032,4 @@ mod llvm_enzyme {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) use llvm_enzyme::expand;
|
pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};
|
||||||
|
|||||||
@@ -112,7 +112,8 @@ pub fn register_builtin_macros(resolver: &mut dyn ResolverExpand) {
|
|||||||
|
|
||||||
register_attr! {
|
register_attr! {
|
||||||
alloc_error_handler: alloc_error_handler::expand,
|
alloc_error_handler: alloc_error_handler::expand,
|
||||||
autodiff: autodiff::expand,
|
autodiff_forward: autodiff::expand_forward,
|
||||||
|
autodiff_reverse: autodiff::expand_reverse,
|
||||||
bench: test::expand_bench,
|
bench: test::expand_bench,
|
||||||
cfg_accessible: cfg_accessible::Expander,
|
cfg_accessible: cfg_accessible::Expander,
|
||||||
cfg_eval: cfg_eval::expand,
|
cfg_eval: cfg_eval::expand,
|
||||||
|
|||||||
@@ -255,7 +255,7 @@ impl<'tcx> CheckAttrVisitor<'tcx> {
|
|||||||
self.check_generic_attr(hir_id, attr, target, Target::Fn);
|
self.check_generic_attr(hir_id, attr, target, Target::Fn);
|
||||||
self.check_proc_macro(hir_id, target, ProcMacroKind::Derive)
|
self.check_proc_macro(hir_id, target, ProcMacroKind::Derive)
|
||||||
}
|
}
|
||||||
[sym::autodiff, ..] => {
|
[sym::autodiff_forward, ..] | [sym::autodiff_reverse, ..] => {
|
||||||
self.check_autodiff(hir_id, attr, span, target)
|
self.check_autodiff(hir_id, attr, span, target)
|
||||||
}
|
}
|
||||||
[sym::coroutine, ..] => {
|
[sym::coroutine, ..] => {
|
||||||
|
|||||||
@@ -531,7 +531,8 @@ symbols! {
|
|||||||
audit_that,
|
audit_that,
|
||||||
augmented_assignments,
|
augmented_assignments,
|
||||||
auto_traits,
|
auto_traits,
|
||||||
autodiff,
|
autodiff_forward,
|
||||||
|
autodiff_reverse,
|
||||||
automatically_derived,
|
automatically_derived,
|
||||||
avx,
|
avx,
|
||||||
avx10_target_feature,
|
avx10_target_feature,
|
||||||
|
|||||||
@@ -229,7 +229,7 @@ pub mod assert_matches {
|
|||||||
/// Unstable module containing the unstable `autodiff` macro.
|
/// Unstable module containing the unstable `autodiff` macro.
|
||||||
pub mod autodiff {
|
pub mod autodiff {
|
||||||
#[unstable(feature = "autodiff", issue = "124509")]
|
#[unstable(feature = "autodiff", issue = "124509")]
|
||||||
pub use crate::macros::builtin::autodiff;
|
pub use crate::macros::builtin::{autodiff_forward, autodiff_reverse};
|
||||||
}
|
}
|
||||||
|
|
||||||
#[unstable(feature = "contracts", issue = "128044")]
|
#[unstable(feature = "contracts", issue = "128044")]
|
||||||
|
|||||||
@@ -1536,6 +1536,20 @@ pub(crate) mod builtin {
|
|||||||
/* compiler built-in */
|
/* compiler built-in */
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[unstable(feature = "autodiff", issue = "124509")]
|
||||||
|
#[allow_internal_unstable(rustc_attrs)]
|
||||||
|
#[rustc_builtin_macro]
|
||||||
|
pub macro autodiff_forward($item:item) {
|
||||||
|
/* compiler built-in */
|
||||||
|
}
|
||||||
|
|
||||||
|
#[unstable(feature = "autodiff", issue = "124509")]
|
||||||
|
#[allow_internal_unstable(rustc_attrs)]
|
||||||
|
#[rustc_builtin_macro]
|
||||||
|
pub macro autodiff_reverse($item:item) {
|
||||||
|
/* compiler built-in */
|
||||||
|
}
|
||||||
|
|
||||||
/// Asserts that a boolean expression is `true` at runtime.
|
/// Asserts that a boolean expression is `true` at runtime.
|
||||||
///
|
///
|
||||||
/// This will invoke the [`panic!`] macro if the provided expression cannot be
|
/// This will invoke the [`panic!`] macro if the provided expression cannot be
|
||||||
|
|||||||
Reference in New Issue
Block a user