Initial naive implementation using Symbols to represent autodiff modes (Forward, Reverse)
Since the mode is no longer part of `meta_item`, we must insert it manually (otherwise macro expansion with `#[rustc_autodiff]` won't work). This can be revised later if a more structured representation becomes necessary (using enums, annotated structs, etc). Some tests are currently failing. I'll address them next.
This commit is contained in:
@@ -259,29 +259,41 @@ mod llvm_enzyme {
|
|||||||
// create TokenStream from vec elemtents:
|
// create TokenStream from vec elemtents:
|
||||||
// meta_item doesn't have a .tokens field
|
// meta_item doesn't have a .tokens field
|
||||||
let mut ts: Vec<TokenTree> = vec![];
|
let mut ts: Vec<TokenTree> = vec![];
|
||||||
if meta_item_vec.len() < 2 {
|
if meta_item_vec.len() < 1 {
|
||||||
// At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
|
// At the bare minimum, we need a fnc name.
|
||||||
// input and output args.
|
|
||||||
dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
|
dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
|
||||||
return vec![item];
|
return vec![item];
|
||||||
}
|
}
|
||||||
|
|
||||||
meta_item_inner_to_ts(&meta_item_vec[1], &mut ts);
|
let mode_symbol = match mode {
|
||||||
|
DiffMode::Forward => sym::Forward,
|
||||||
|
DiffMode::Reverse => sym::Reverse,
|
||||||
|
_ => unreachable!("Unsupported mode: {:?}", mode),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Insert mode token
|
||||||
|
let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default());
|
||||||
|
ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint));
|
||||||
|
ts.insert(
|
||||||
|
1,
|
||||||
|
TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone),
|
||||||
|
);
|
||||||
|
|
||||||
// Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
|
// Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
|
||||||
// If it is not given, we default to 1 (scalar mode).
|
// If it is not given, we default to 1 (scalar mode).
|
||||||
let start_position;
|
let start_position;
|
||||||
let kind: LitKind = LitKind::Integer;
|
let kind: LitKind = LitKind::Integer;
|
||||||
let symbol;
|
let symbol;
|
||||||
if meta_item_vec.len() >= 3
|
if meta_item_vec.len() >= 2
|
||||||
&& let Some(width) = width(&meta_item_vec[2])
|
&& let Some(width) = width(&meta_item_vec[1])
|
||||||
{
|
{
|
||||||
start_position = 3;
|
start_position = 2;
|
||||||
symbol = Symbol::intern(&width.to_string());
|
symbol = Symbol::intern(&width.to_string());
|
||||||
} else {
|
} else {
|
||||||
start_position = 2;
|
start_position = 1;
|
||||||
symbol = sym::integer(1);
|
symbol = sym::integer(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
let l: Lit = Lit { kind, symbol, suffix: None };
|
let l: Lit = Lit { kind, symbol, suffix: None };
|
||||||
let t = Token::new(TokenKind::Literal(l), Span::default());
|
let t = Token::new(TokenKind::Literal(l), Span::default());
|
||||||
let comma = Token::new(TokenKind::Comma, Span::default());
|
let comma = Token::new(TokenKind::Comma, Span::default());
|
||||||
|
|||||||
@@ -253,6 +253,7 @@ symbols! {
|
|||||||
FnMut,
|
FnMut,
|
||||||
FnOnce,
|
FnOnce,
|
||||||
Formatter,
|
Formatter,
|
||||||
|
Forward,
|
||||||
From,
|
From,
|
||||||
FromIterator,
|
FromIterator,
|
||||||
FromResidual,
|
FromResidual,
|
||||||
@@ -348,6 +349,7 @@ symbols! {
|
|||||||
Result,
|
Result,
|
||||||
ResumeTy,
|
ResumeTy,
|
||||||
Return,
|
Return,
|
||||||
|
Reverse,
|
||||||
Right,
|
Right,
|
||||||
Rust,
|
Rust,
|
||||||
RustaceansAreAwesome,
|
RustaceansAreAwesome,
|
||||||
|
|||||||
Reference in New Issue
Block a user