remove noinline attribute and add alwaysinline after AD pass

This commit is contained in:
bit-aloo
2025-04-20 15:54:12 +05:30
parent 9bc04016e6
commit 7018392337
8 changed files with 104 additions and 10 deletions

View File

@@ -1,5 +1,4 @@
//! Set and unset common attributes on LLVM values.
use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr};
use rustc_codegen_ssa::traits::*;
use rustc_hir::def_id::DefId;
@@ -32,7 +31,7 @@ pub(crate) fn has_attr(llfn: &Value, idx: AttributePlace, attr: AttributeKind) -
llvm::HasAttributeAtIndex(llfn, idx, attr)
}
pub(crate) fn has_string_attr(llfn: &Value, name: *const i8) -> bool {
pub(crate) fn has_string_attr(llfn: &Value, name: &str) -> bool {
llvm::HasStringAttribute(llfn, name)
}
@@ -40,7 +39,7 @@ pub(crate) fn remove_from_llfn(llfn: &Value, place: AttributePlace, kind: Attrib
llvm::RemoveRustEnumAttributeAtIndex(llfn, place, kind);
}
pub(crate) fn remove_string_attr_from_llfn(llfn: &Value, name: *const i8) {
pub(crate) fn remove_string_attr_from_llfn(llfn: &Value, name: &str) {
llvm::RemoveStringAttrFromFn(llfn, name);
}

View File

@@ -28,8 +28,9 @@ use crate::back::write::{
use crate::errors::{
DynamicLinkingWithLTO, LlvmError, LtoBitcodeFromRlib, LtoDisallowed, LtoDylib, LtoProcMacro,
};
use crate::llvm::AttributePlace::Function;
use crate::llvm::{self, build_string};
use crate::{LlvmCodegenBackend, ModuleLlvm};
use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, attributes};
/// We keep track of the computed LTO cache keys from the previous
/// session to determine which CGUs we can reuse.
@@ -666,6 +667,31 @@ pub(crate) fn run_pass_manager(
}
if cfg!(llvm_enzyme) && enable_ad && !thin {
let cx =
SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
for function in cx.get_functions() {
let enzyme_marker = "enzyme_marker";
if attributes::has_string_attr(function, enzyme_marker) {
// Sanity check: Ensure 'noinline' is present before replacing it.
assert!(
!attributes::has_attr(function, Function, llvm::AttributeKind::NoInline),
"Expected __enzyme function to have 'noinline' before adding 'alwaysinline'"
);
attributes::remove_from_llfn(function, Function, llvm::AttributeKind::NoInline);
attributes::remove_string_attr_from_llfn(function, enzyme_marker);
assert!(
!attributes::has_string_attr(function, enzyme_marker),
"Expected function to not have 'enzyme_marker'"
);
let always_inline = llvm::AttributeKind::AlwaysInline.create_attr(cx.llcx);
attributes::apply_to_llfn(function, Function, &[always_inline]);
}
}
let opt_stage = llvm::OptStage::FatLTO;
let stage = write::AutodiffStage::PostAD;
if !config.autodiff.contains(&config::AutoDiff::NoPostopt) {

View File

@@ -698,6 +698,16 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
llvm::LLVMMDStringInContext2(self.llcx(), name.as_ptr() as *const c_char, name.len())
})
}
pub(crate) fn get_functions(&self) -> Vec<&'ll Value> {
let mut functions = vec![];
let mut func = unsafe { llvm::LLVMGetFirstFunction(self.llmod()) };
while let Some(f) = func {
functions.push(f);
func = unsafe { llvm::LLVMGetNextFunction(f) }
}
functions
}
}
impl<'ll, 'tcx> MiscCodegenMethods<'tcx> for CodegenCx<'ll, 'tcx> {

View File

@@ -19,8 +19,12 @@ unsafe extern "C" {
pub(crate) fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
pub(crate) fn LLVMRustHasAttributeAtIndex(V: &Value, i: c_uint, Kind: AttributeKind) -> bool;
pub(crate) fn LLVMRustGetArrayNumElements(Ty: &Type) -> u64;
pub(crate) fn LLVMRustHasFnAttribute(F: &Value, Name: *const c_char) -> bool;
pub(crate) fn LLVMRustRemoveFnAttribute(F: &Value, Name: *const c_char);
pub(crate) fn LLVMRustHasFnAttribute(
F: &Value,
Name: *const c_char,
NameLen: libc::size_t,
) -> bool;
pub(crate) fn LLVMRustRemoveFnAttribute(F: &Value, Name: *const c_char, NameLen: libc::size_t);
pub(crate) fn LLVMGetFirstFunction(M: &Module) -> Option<&Value>;
pub(crate) fn LLVMGetNextFunction(Fn: &Value) -> Option<&Value>;
pub(crate) fn LLVMRustRemoveEnumAttributeAtIndex(

View File

@@ -41,6 +41,32 @@ pub(crate) fn AddFunctionAttributes<'ll>(
}
}
pub(crate) fn HasAttributeAtIndex<'ll>(
llfn: &'ll Value,
idx: AttributePlace,
kind: AttributeKind,
) -> bool {
unsafe { LLVMRustHasAttributeAtIndex(llfn, idx.as_uint(), kind) }
}
pub(crate) fn HasStringAttribute<'ll>(llfn: &'ll Value, name: &str) -> bool {
unsafe { LLVMRustHasFnAttribute(llfn, name.as_c_char_ptr(), name.len()) }
}
pub(crate) fn RemoveStringAttrFromFn<'ll>(llfn: &'ll Value, name: &str) {
unsafe { LLVMRustRemoveFnAttribute(llfn, name.as_c_char_ptr(), name.len()) }
}
pub(crate) fn RemoveRustEnumAttributeAtIndex(
llfn: &Value,
place: AttributePlace,
kind: AttributeKind,
) {
unsafe {
LLVMRustRemoveEnumAttributeAtIndex(llfn, place.as_uint(), kind);
}
}
pub(crate) fn AddCallSiteAttributes<'ll>(
callsite: &'ll Value,
idx: AttributePlace,

View File

@@ -128,6 +128,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
(**self).borrow().llcx
}
pub(crate) fn llmod(&self) -> &'ll llvm::Module {
(**self).borrow().llmod
}
pub(crate) fn isize_ty(&self) -> &'ll Type {
(**self).borrow().isize_ty
}