Auto merge of #142544 - Sa4dUs:prevent-abi-changes, r=ZuseZ4
Prevent ABI changes affect EnzymeAD This PR handles ABI changes for autodiff input arguments to improve Enzyme compatibility. Fundamentally this adjusts activities when a function argument is lowered as an `ScalarPair`, so there's no mismatch between diff activities and args. Also removes activities corresponding to ZSTs. fixes: https://github.com/rust-lang/rust/issues/144025 r? `@ZuseZ4`
This commit is contained in:
@@ -3,8 +3,9 @@ use std::ptr;
|
||||
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
|
||||
use rustc_codegen_ssa::common::TypeKind;
|
||||
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
|
||||
use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
|
||||
use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
|
||||
use rustc_middle::{bug, ty};
|
||||
use rustc_target::callconv::PassMode;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::builder::{Builder, PlaceRef, UNNAMED};
|
||||
@@ -16,9 +17,12 @@ use crate::value::Value;
|
||||
|
||||
pub(crate) fn adjust_activity_to_abi<'tcx>(
|
||||
tcx: TyCtxt<'tcx>,
|
||||
fn_ty: Ty<'tcx>,
|
||||
instance: Instance<'tcx>,
|
||||
typing_env: TypingEnv<'tcx>,
|
||||
da: &mut Vec<DiffActivity>,
|
||||
) {
|
||||
let fn_ty = instance.ty(tcx, typing_env);
|
||||
|
||||
if !matches!(fn_ty.kind(), ty::FnDef(..)) {
|
||||
bug!("expected fn def for autodiff, got {:?}", fn_ty);
|
||||
}
|
||||
@@ -27,8 +31,16 @@ pub(crate) fn adjust_activity_to_abi<'tcx>(
|
||||
// All we do is decide how to handle the arguments.
|
||||
let sig = fn_ty.fn_sig(tcx).skip_binder();
|
||||
|
||||
// FIXME(Sa4dUs): pass proper varargs once we have support for differentiating variadic functions
|
||||
let Ok(fn_abi) =
|
||||
tcx.fn_abi_of_instance(typing_env.as_query_input((instance, ty::List::empty())))
|
||||
else {
|
||||
bug!("failed to get fn_abi of instance with empty varargs");
|
||||
};
|
||||
|
||||
let mut new_activities = vec![];
|
||||
let mut new_positions = vec![];
|
||||
let mut del_activities = 0;
|
||||
for (i, ty) in sig.inputs().iter().enumerate() {
|
||||
if let Some(inner_ty) = ty.builtin_deref(true) {
|
||||
if inner_ty.is_slice() {
|
||||
@@ -80,6 +92,34 @@ pub(crate) fn adjust_activity_to_abi<'tcx>(
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
let pci = PseudoCanonicalInput { typing_env: TypingEnv::fully_monomorphized(), value: *ty };
|
||||
|
||||
let layout = match tcx.layout_of(pci) {
|
||||
Ok(layout) => layout.layout,
|
||||
Err(_) => {
|
||||
bug!("failed to compute layout for type {:?}", ty);
|
||||
}
|
||||
};
|
||||
|
||||
let pass_mode = &fn_abi.args[i].mode;
|
||||
|
||||
// For ZST, just ignore and don't add its activity, as this arg won't be present
|
||||
// in the LLVM passed to Enzyme.
|
||||
// Some targets pass ZST indirectly in the C ABI, in that case, handle it as a normal arg
|
||||
// FIXME(Sa4dUs): Enforce ZST corresponding diff activity be `Const`
|
||||
if *pass_mode == PassMode::Ignore {
|
||||
del_activities += 1;
|
||||
da.remove(i);
|
||||
}
|
||||
|
||||
// If the argument is lowered as a `ScalarPair`, we need to duplicate its activity.
|
||||
// Otherwise, the number of activities won't match the number of LLVM arguments and
|
||||
// this will lead to errors when verifying the Enzyme call.
|
||||
if let rustc_abi::BackendRepr::ScalarPair(_, _) = layout.backend_repr() {
|
||||
new_activities.push(da[i].clone());
|
||||
new_positions.push(i + 1 - del_activities);
|
||||
}
|
||||
}
|
||||
// now add the extra activities coming from slices
|
||||
// Reverse order to not invalidate the indices
|
||||
|
||||
@@ -1208,7 +1208,8 @@ fn codegen_autodiff<'ll, 'tcx>(
|
||||
|
||||
adjust_activity_to_abi(
|
||||
tcx,
|
||||
fn_source.ty(tcx, TypingEnv::fully_monomorphized()),
|
||||
fn_source,
|
||||
TypingEnv::fully_monomorphized(),
|
||||
&mut diff_attrs.input_activity,
|
||||
);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user