Initial naive implementation using Symbols to represent autodiff modes (Forward, Reverse)
Since the mode is no longer part of `meta_item`, we must insert it manually (otherwise macro expansion with `#[rustc_autodiff]` won't work). This can be revised later if a more structured representation becomes necessary (using enums, annotated structs, etc). Some tests are currently failing. I'll address them next.
This commit is contained in:
@@ -259,29 +259,41 @@ mod llvm_enzyme {
|
||||
// create TokenStream from vec elemtents:
|
||||
// meta_item doesn't have a .tokens field
|
||||
let mut ts: Vec<TokenTree> = vec![];
|
||||
if meta_item_vec.len() < 2 {
|
||||
// At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
|
||||
// input and output args.
|
||||
if meta_item_vec.len() < 1 {
|
||||
// At the bare minimum, we need a fnc name.
|
||||
dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
|
||||
return vec![item];
|
||||
}
|
||||
|
||||
meta_item_inner_to_ts(&meta_item_vec[1], &mut ts);
|
||||
let mode_symbol = match mode {
|
||||
DiffMode::Forward => sym::Forward,
|
||||
DiffMode::Reverse => sym::Reverse,
|
||||
_ => unreachable!("Unsupported mode: {:?}", mode),
|
||||
};
|
||||
|
||||
// Insert mode token
|
||||
let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default());
|
||||
ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint));
|
||||
ts.insert(
|
||||
1,
|
||||
TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone),
|
||||
);
|
||||
|
||||
// Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
|
||||
// If it is not given, we default to 1 (scalar mode).
|
||||
let start_position;
|
||||
let kind: LitKind = LitKind::Integer;
|
||||
let symbol;
|
||||
if meta_item_vec.len() >= 3
|
||||
&& let Some(width) = width(&meta_item_vec[2])
|
||||
if meta_item_vec.len() >= 2
|
||||
&& let Some(width) = width(&meta_item_vec[1])
|
||||
{
|
||||
start_position = 3;
|
||||
start_position = 2;
|
||||
symbol = Symbol::intern(&width.to_string());
|
||||
} else {
|
||||
start_position = 2;
|
||||
start_position = 1;
|
||||
symbol = sym::integer(1);
|
||||
}
|
||||
|
||||
let l: Lit = Lit { kind, symbol, suffix: None };
|
||||
let t = Token::new(TokenKind::Literal(l), Span::default());
|
||||
let comma = Token::new(TokenKind::Comma, Span::default());
|
||||
|
||||
@@ -253,6 +253,7 @@ symbols! {
|
||||
FnMut,
|
||||
FnOnce,
|
||||
Formatter,
|
||||
Forward,
|
||||
From,
|
||||
FromIterator,
|
||||
FromResidual,
|
||||
@@ -348,6 +349,7 @@ symbols! {
|
||||
Result,
|
||||
ResumeTy,
|
||||
Return,
|
||||
Reverse,
|
||||
Right,
|
||||
Rust,
|
||||
RustaceansAreAwesome,
|
||||
|
||||
Reference in New Issue
Block a user