Rollup merge of #139351 - EnzymeAD:autodiff-batching2, r=oli-obk
Autodiff batching2 ~I will rebase it once my first PR landed.~ done. This autodiff batch mode is more similar to scalar autodiff, since it still only takes one shadow argument. However, that argument is supposed to be `width` times larger. r? `@oli-obk` Tracking: - https://github.com/rust-lang/rust/issues/124509
This commit is contained in:
@@ -799,8 +799,19 @@ mod llvm_enzyme {
|
||||
d_inputs.push(shadow_arg.clone());
|
||||
}
|
||||
}
|
||||
DiffActivity::Dual | DiffActivity::DualOnly => {
|
||||
for i in 0..x.width {
|
||||
DiffActivity::Dual
|
||||
| DiffActivity::DualOnly
|
||||
| DiffActivity::Dualv
|
||||
| DiffActivity::DualvOnly => {
|
||||
// the *v variants get lowered to enzyme_dupv and enzyme_dupnoneedv, which cause
|
||||
// Enzyme to not expect N arguments, but one argument (which is instead larger).
|
||||
let iterations =
|
||||
if matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) {
|
||||
1
|
||||
} else {
|
||||
x.width
|
||||
};
|
||||
for i in 0..iterations {
|
||||
let mut shadow_arg = arg.clone();
|
||||
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
||||
ident.name
|
||||
@@ -823,7 +834,7 @@ mod llvm_enzyme {
|
||||
DiffActivity::Const => {
|
||||
// Nothing to do here.
|
||||
}
|
||||
DiffActivity::None | DiffActivity::FakeActivitySize => {
|
||||
DiffActivity::None | DiffActivity::FakeActivitySize(_) => {
|
||||
panic!("Should not happen");
|
||||
}
|
||||
}
|
||||
@@ -887,8 +898,8 @@ mod llvm_enzyme {
|
||||
}
|
||||
};
|
||||
|
||||
if let DiffActivity::Dual = x.ret_activity {
|
||||
let kind = if x.width == 1 {
|
||||
if matches!(x.ret_activity, DiffActivity::Dual | DiffActivity::Dualv) {
|
||||
let kind = if x.width == 1 || matches!(x.ret_activity, DiffActivity::Dualv) {
|
||||
// Dual can only be used for f32/f64 ret.
|
||||
// In that case we return now a tuple with two floats.
|
||||
TyKind::Tup(thin_vec![ty.clone(), ty.clone()])
|
||||
@@ -903,7 +914,7 @@ mod llvm_enzyme {
|
||||
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
|
||||
d_decl.output = FnRetTy::Ty(ty);
|
||||
}
|
||||
if let DiffActivity::DualOnly = x.ret_activity {
|
||||
if matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) {
|
||||
// No need to change the return type,
|
||||
// we will just return the shadow in place of the primal return.
|
||||
// However, if we have a width > 1, then we don't return -> T, but -> [T; width]
|
||||
|
||||
Reference in New Issue
Block a user