Rollup merge of #139465 - EnzymeAD:autodiff-sret, r=oli-obk
add sret handling for scalar autodiff r? `@oli-obk` Fixing one of the todo's which I left in my previous batching PR. This one handles sret for scalar autodiff. `sret` mostly shows up when we try to return a lot of scalar floats. People often start testing autodiff which toy functions which just use a few scalars as inputs and outputs, and those were the most likely to be affected by this issue. So this fix should make learning/teaching hopefully a bit easier. Tracking: - https://github.com/rust-lang/rust/issues/124509
This commit is contained in:
@@ -92,6 +92,12 @@ pub struct AutoDiffAttrs {
|
||||
pub input_activity: Vec<DiffActivity>,
|
||||
}
|
||||
|
||||
impl AutoDiffAttrs {
|
||||
pub fn has_primal_ret(&self) -> bool {
|
||||
matches!(self.ret_activity, DiffActivity::Active | DiffActivity::Dual)
|
||||
}
|
||||
}
|
||||
|
||||
impl DiffMode {
|
||||
pub fn is_rev(&self) -> bool {
|
||||
matches!(self, DiffMode::Reverse)
|
||||
|
||||
Reference in New Issue
Block a user