Initial naive implementation using Symbols to represent autodiff modes (Forward, Reverse)

Since the mode is no longer part of `meta_item`, we must insert it manually (otherwise macro expansion with `#[rustc_autodiff]` won't work).

This can be revised later if a more structured representation becomes necessary (using enums, annotated structs, etc).

Some tests are currently failing. I'll address them next.
This commit is contained in:
Marcelo Domínguez
2025-05-10 00:52:47 +00:00
parent 2041de7083
commit f92d84cc6e
2 changed files with 22 additions and 8 deletions

View File

@@ -259,29 +259,41 @@ mod llvm_enzyme {
// create TokenStream from vec elemtents:
// meta_item doesn't have a .tokens field
let mut ts: Vec<TokenTree> = vec![];
if meta_item_vec.len() < 2 {
// At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
// input and output args.
if meta_item_vec.len() < 1 {
// At the bare minimum, we need a fnc name.
dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
return vec![item];
}
meta_item_inner_to_ts(&meta_item_vec[1], &mut ts);
let mode_symbol = match mode {
DiffMode::Forward => sym::Forward,
DiffMode::Reverse => sym::Reverse,
_ => unreachable!("Unsupported mode: {:?}", mode),
};
// Insert mode token
let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default());
ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint));
ts.insert(
1,
TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone),
);
// Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
// If it is not given, we default to 1 (scalar mode).
let start_position;
let kind: LitKind = LitKind::Integer;
let symbol;
if meta_item_vec.len() >= 3
&& let Some(width) = width(&meta_item_vec[2])
if meta_item_vec.len() >= 2
&& let Some(width) = width(&meta_item_vec[1])
{
start_position = 3;
start_position = 2;
symbol = Symbol::intern(&width.to_string());
} else {
start_position = 2;
start_position = 1;
symbol = sym::integer(1);
}
let l: Lit = Lit { kind, symbol, suffix: None };
let t = Token::new(TokenKind::Literal(l), Span::default());
let comma = Token::new(TokenKind::Comma, Span::default());

View File

@@ -253,6 +253,7 @@ symbols! {
FnMut,
FnOnce,
Formatter,
Forward,
From,
FromIterator,
FromResidual,
@@ -348,6 +349,7 @@ symbols! {
Result,
ResumeTy,
Return,
Reverse,
Right,
Rust,
RustaceansAreAwesome,