Support ZST args
This commit is contained in:
@@ -29,6 +29,7 @@ pub(crate) fn adjust_activity_to_abi<'tcx>(
|
|||||||
|
|
||||||
let mut new_activities = vec![];
|
let mut new_activities = vec![];
|
||||||
let mut new_positions = vec![];
|
let mut new_positions = vec![];
|
||||||
|
let mut del_activities = 0;
|
||||||
for (i, ty) in sig.inputs().iter().enumerate() {
|
for (i, ty) in sig.inputs().iter().enumerate() {
|
||||||
if let Some(inner_ty) = ty.builtin_deref(true) {
|
if let Some(inner_ty) = ty.builtin_deref(true) {
|
||||||
if inner_ty.is_slice() {
|
if inner_ty.is_slice() {
|
||||||
@@ -90,12 +91,20 @@ pub(crate) fn adjust_activity_to_abi<'tcx>(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// For ZST, just ignore and don't add its activity, as this arg won't be present
|
||||||
|
// in the LLVM passed to Enzyme.
|
||||||
|
// FIXME(Sa4dUs): Enforce ZST corresponding diff activity be `Const`
|
||||||
|
if layout.is_zst() {
|
||||||
|
del_activities += 1;
|
||||||
|
da.remove(i);
|
||||||
|
}
|
||||||
|
|
||||||
// If the argument is lowered as a `ScalarPair`, we need to duplicate its activity.
|
// If the argument is lowered as a `ScalarPair`, we need to duplicate its activity.
|
||||||
// Otherwise, the number of activities won't match the number of LLVM arguments and
|
// Otherwise, the number of activities won't match the number of LLVM arguments and
|
||||||
// this will lead to errors when verifying the Enzyme call.
|
// this will lead to errors when verifying the Enzyme call.
|
||||||
if let rustc_abi::BackendRepr::ScalarPair(_, _) = layout.backend_repr() {
|
if let rustc_abi::BackendRepr::ScalarPair(_, _) = layout.backend_repr() {
|
||||||
new_activities.push(da[i].clone());
|
new_activities.push(da[i].clone());
|
||||||
new_positions.push(i + 1);
|
new_positions.push(i + 1 - del_activities);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// now add the extra activities coming from slices
|
// now add the extra activities coming from slices
|
||||||
|
|||||||
17
tests/ui/autodiff/zst.rs
Normal file
17
tests/ui/autodiff/zst.rs
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
|
||||||
|
//@ no-prefer-dynamic
|
||||||
|
//@ needs-enzyme
|
||||||
|
//@ build-pass
|
||||||
|
|
||||||
|
// Check that differentiating functions with ZST args does not break
|
||||||
|
|
||||||
|
#![feature(autodiff)]
|
||||||
|
|
||||||
|
#[core::autodiff::autodiff_forward(fd_inner, Const, Dual)]
|
||||||
|
fn f(_zst: (), _x: &mut f64) {}
|
||||||
|
|
||||||
|
fn fd(x: &mut f64, xd: &mut f64) {
|
||||||
|
fd_inner((), x, xd);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
||||||
Reference in New Issue
Block a user